Compare commits


133 Commits

Author SHA1 Message Date
21b271d088 Automated submodule update: FBGEMM 2025-10-31 13:24:08 -07:00
8209a0506b [Pytorch] Enable aarch64 convert autovec only on clang (#166739)
Summary: We've noted issues with modern GCC versions. Until further investigation is carried out, we'll leave the code enabled only on clang.

Test Plan: CI

Differential Revision: D85968395

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166739
Approved by: https://github.com/mcfi, https://github.com/Skylion007, https://github.com/robert-hardwick
2025-10-31 20:22:33 +00:00
70aeb49198 [dynamo] clarify graph break handling/logging in symbolic_convert (#166587)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166587
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166476, #166477, #166586
2025-10-31 20:13:16 +00:00
cf9a834f39 [BE] Move GreenContext implementation details to cpp (#166462)
- Remove all complex defines logic from the header
- Make GreenContext constructor private, as it should only be created via the static method as a singleton
- Delete unused `getContext` and `getGreenContext` methods
- Rename `CUDA_HAS_GREEN_CONTEXT` to `HAS_CUDA_GREEN_CONTEXT()`, which results in a compilation error if one accidentally makes a typo
- Suppress `-Wunused-private-field` if GreenContext is not available
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166462
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-31 20:11:02 +00:00
856a7a5298 Add missing device to namedtensor tests (#166717)
This PR passes the previously unused `device` argument to the tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166717
Approved by: https://github.com/Skylion007
2025-10-31 20:04:41 +00:00
ef8d97efcf fix broken nn_convolution test (#166666)
Summary: Broken by an OSS diff landed during oncall by a third-party contributor

Test Plan: buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test:nn_convolution -- --run-disabled

Differential Revision: D85899891

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166666
Approved by: https://github.com/atalman, https://github.com/seemethere, https://github.com/Skylion007
2025-10-31 19:59:50 +00:00
d2be06f673 [cpu][fix] Update ACL version to fix crashes with tensor sizes > 2^31-1 (#165904)
----

- Updates Arm Compute Library (ACL) to v52.6.0
- v52.6.0 contains https://github.com/ARM-software/ComputeLibrary/pull/1201 which fixes crashes with tensors of sizes > 2^31-1

fixes: #165654

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165904
Approved by: https://github.com/malfet
2025-10-31 19:37:26 +00:00
08f4535378 Refactor AOTAutogradCacheEntry into AOTAutogradResult (#166656)
This PR refactors the name AOTAutogradCacheEntry into AOTAutogradResult, and BundledAOTAutogradCacheEntry into BundledAOTAutogradResult. It also moves all corresponding files to a new file, `aot_autograd_result`, which is analogous to `output_code.py` from Inductor.

Having all these be called cache entries made sense when all we used them for was caching. But with AOT compile using BundledAOTAutogradCacheEntry, we want a more generalized naming structure.

This is a no-op change, and all existing tests should pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166656
Approved by: https://github.com/zhxchen17
ghstack dependencies: #166650
2025-10-31 18:54:09 +00:00
30157d30f0 Add regional aot eager support to AOTAutogradCacheEntry (#166650)
This PR does two things:

- It genericizes `BundledAOTAutogradCacheEntry` to support *any* OutputCode, not just CompiledFxGraphs
- It adds a brand new OutputCode for the `aot_eager_regional_inductor` backend, i.e. a graph module that has regional inductor components in it.

This allows BundledAOTAutogradCache to just integrate nicely with inductor out of the box, but more importantly, it allows the result of aot_autograd to be fully serializable when using `aot_eager_regional_inductor`. This will allow us to AOT precompile cases where we have an eager graph that has scooped up inductor bits.

It's a bit unfortunate that the naming makes BundledAOTAutogradCacheEntry sound like its primary use is for caching, but really the more common use is going to be as an AOTAutogradOutput. It may be worth revisiting how to refactor/rename these in a later PR:

- AOTAutogradCacheEntry -> AOTAutogradResult
- BundledAOTAutogradCacheEntry -> BundledAOTAutogradResult

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166650
Approved by: https://github.com/zhxchen17
2025-10-31 18:54:09 +00:00
b470e59c38 partitioner option to ignore partitioner_tag for abstract usage (#166725)
Partitioner functionality is appealing to use in different scenarios (e.g., Autoparallel).

We have special logic around the "partitioner_tag" node meta that is only needed for the forward/backward split.

This PR adds an optional argument to skip it and do only a generic split based on inputs/outputs.

Potentially we want to make `_extract_graph_with_inputs_outputs` public (drop the underscore) :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166725
Approved by: https://github.com/bdhirsh
2025-10-31 18:50:02 +00:00
85b85f6c2c Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)"
This reverts commit 108bb224f77842593009214ebf6258030b934642.

Reverted https://github.com/pytorch/pytorch/pull/160843 on behalf of https://github.com/atalman due to failing internal builds ([comment](https://github.com/pytorch/pytorch/pull/160843#issuecomment-3474354428))
2025-10-31 18:31:32 +00:00
b71966f67b [PyTorch] Improve aarch64 performance of bfloat16 ops - retry (#166028) (#166641)
Summary:

This PR allows the compiler to better optimize some bfloat16-based operations when run on NEON.

Retrying to land the code, after noting that these expressions became available in recent compiler versions.

The existing CI benchmark `binary_test.py` will measure the affected codepaths.

Benchmarks show measurable improvements on clang-19 when targeting armv9-a+sve2:

Before:
bfloat16 add: 250.503us
bfloat16 sub: 245.674us
bfloat16 neg: 113.945us
bfloat16 abs: 115.953us
bfloat16 reciprocal: 262.602us

After:
bfloat16 add: 203.862us ---> 23% higher throughput
bfloat16 sub: 201.526us ---> 22% higher throughput
bfloat16 neg: 68.416us ---> 67% higher throughput
bfloat16 abs: 71.003us  ---> 63% higher throughput
bfloat16 reciprocal: 177.834us ---> 48% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Reviewed By: mcfi

Differential Revision: D85809843

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166641
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-10-31 18:21:04 +00:00
0947765eb9 Cache even more work for return_and_correct_aliasing (#166365)
Yet another pass found even more work we can move to be done only once. This seems to knock a few microseconds off the DTensor dispatch fast path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166365
Approved by: https://github.com/bdhirsh
2025-10-31 18:03:05 +00:00
239e7b541a [ROCm][CI] upgrade nightly wheels to ROCm 7.1 (#166730)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166730
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-31 17:30:47 +00:00
ffaa6578b7 Revise deprecation warning for ONNX exporter (#166692)
Updated deprecation warning for ONNX export to reflect the current state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166692
Approved by: https://github.com/titaiwangms
2025-10-31 17:23:55 +00:00
365ed62f61 Document LibTorch ABI more, add README to headeronly (#166661)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166661
Approved by: https://github.com/mikaylagawarecki, https://github.com/albanD
2025-10-31 17:18:13 +00:00
fcc1063566 Revert "[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ (#166569)"
This reverts commit aa9c96af041b26c9c55adac490f3449b98f27d06.

Reverted https://github.com/pytorch/pytorch/pull/166569 on behalf of https://github.com/Lucaskabela due to Lintrunner not fixed due to race condition at landing ([comment](https://github.com/pytorch/pytorch/pull/166569#issuecomment-3474012637))
2025-10-31 16:59:33 +00:00
121235956b update Node.is_impure check if subgraph contains impure ops (#166609)
Summary:
## Context
When `const_fold.split_const_subgraphs` sees a `call_module` node that is a GraphModule, the existing implementation can mark the node as const-foldable when it shouldn't.

For example, a parent graph contains a `call_module` to a subgraph that has no inputs but contains impure ops inside.
```
parent graph():
    %sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
    return (getitem,)

submodule graph():
    %randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
    return (add,)
```
When the `submodule` graph is fed to `const_fold.split_const_subgraphs` directly, it comes out unmodified since `randn` is impure.

But when the `parent` graph that calls `submodule` is fed to `const_fold.split_const_subgraphs`, it comes out folded.
```
parent after fold graph():
    %_fx_const_folded_attrs : [num_users=1] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
    return (_fx_const_folded_attrs,)
```

This is because the `node.is_impure()` check inside `const_fold.split_const_subgraphs` falls through, leading the `call_module` node to be marked as pure.

## Fix

We can update the `fx.node.Node.is_impure` function to check for ops inside a `call_module` node with an additional `subgraph_has_impure_ops` check:
- if a `call_module` node calls a GraphModule,
- check whether any of its `call_function` nodes are impure ops, and
- recursively check any `call_module` nodes that call a GraphModule.

If the `call_module` subgraph has impure ops, `is_impure` returns True. A sketch of this recursive check follows.
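
A minimal sketch of such a recursive check, assuming the helper is named `subgraph_has_impure_ops` and reusing the existing `Node.is_impure()` API (illustrative, not the exact code in the PR):
```python
import torch.fx as fx

def subgraph_has_impure_ops(gm: fx.GraphModule) -> bool:
    # Walk the subgraph; an impure call_function node (e.g. aten.randn)
    # makes the calling call_module node impure.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.is_impure():
            return True
        # Recurse into nested GraphModules reached via call_module.
        if node.op == "call_module":
            submod = gm.get_submodule(node.target)
            if isinstance(submod, fx.GraphModule) and subgraph_has_impure_ops(submod):
                return True
    return False
```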

Test Plan: added tests to test_fx_const_fold.py

Differential Revision: D85798483

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166609
Approved by: https://github.com/blaine-rister
2025-10-31 16:58:18 +00:00
aa9c96af04 [BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ (#166569)
Provides type coverage to ~3000 LOC and 200 methods in  `torch/_dynamo/variables/`

This is the first part of the final step to having 100% strict type coverage in dynamo - see previous comments in https://github.com/pytorch/pytorch/pull/166535 (combined into this one PR because ghstack was giving issues...)

### Coverage report:
```
mypy torch/_dynamo/variables --linecount-report /tmp/coverage_log
```
Compare before to after - we go from 3826 to 7221 lines covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166569
Approved by: https://github.com/williamwen42
2025-10-31 16:56:50 +00:00
c3b71d5499 [ROCm][CI] remove relaxed tolerance for tf32 tests (#166478)
Instead of relaxing tolerances for certain unit tests that exercise TF32 on MI300, skip the tests until hipblaslt accuracy is improved.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166478
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
2025-10-31 16:15:42 +00:00
1e3600b528 [MPS] Move logaddexp/logaddexp2 to Metal and support complex (#166670)
NOTE: Complex inputs are only supported in `logaddexp`. Since `logaddexp2` does not support complex inputs for CPU, it is not enabled for MPS in this PR either.
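
A quick usage check of the newly supported path (sketch; assumes an MPS-capable machine):
```python
import torch

if torch.backends.mps.is_available():
    a = torch.randn(4, dtype=torch.cfloat, device="mps")
    b = torch.randn(4, dtype=torch.cfloat, device="mps")
    # logaddexp now runs as a Metal kernel and accepts complex inputs
    print(torch.logaddexp(a, b))
```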
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166670
Approved by: https://github.com/malfet
2025-10-31 16:15:02 +00:00
fee7624bd6 [PT2] set choice handler in config (#166607)
Summary:
We were setting the custom inductor choice using `torch._inductor.virtualized.V.set_choices_handler(CustomInductorChoices())`. However, this leads to inconsistent behaviors, even for jobs that are submitted back to back.

In this diff, we pass in the choice handler via an inductor config and overwrite the default behavior when the config is provided. This solves the inconsistent behavior.

Test Plan: see D85785892 (internal only)

Differential Revision: D85785879

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166607
Approved by: https://github.com/eellison
2025-10-31 15:40:05 +00:00
24e94e021a [ROCm][CI] create ROCm 7.1 magma tarball (#166693)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166693
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-31 15:20:00 +00:00
69be99ee51 Remove manually synced arch versions in tools/nightly.py (#166616)
Discussed with @atalman offline. To reduce duplicate changes and reduce the number of files to change when updating arch versions.

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166616
Approved by: https://github.com/ezyang
2025-10-31 15:11:28 +00:00
034e951b0c [CUDA][cuBLASLt] addmm -- extend bias fusions to cases with (1 by n) shapes (#166307)
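A sketch of the addmm pattern the fusion now covers (illustrative only; the fusion is a backend detail and the call site is unchanged):
```python
import torch

if torch.cuda.is_available():
    m, k, n = 32, 64, 16
    bias = torch.randn(1, n, device="cuda", dtype=torch.half)  # (1, n) bias, broadcast over rows
    a = torch.randn(m, k, device="cuda", dtype=torch.half)
    b = torch.randn(k, n, device="cuda", dtype=torch.half)
    out = torch.addmm(bias, a, b)   # bias epilogue can now be fused by cuBLASLt
    print(out.shape)                # torch.Size([32, 16])
```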
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166307
Approved by: https://github.com/eqy
2025-10-31 14:30:41 +00:00
160ab53dd5 Update weight tensor initialization in RMSNormalization (#166550)
Ensure the weight is a >1-D tensor for ORT compatibility.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166550
Approved by: https://github.com/titaiwangms
2025-10-31 14:29:27 +00:00
5bcfdae71d Revert "Make PT2 compile backprop through custom op without autograd key a hard error (#166367)"
This reverts commit 4acc66f1192ab7743abcc50383aefc5447447f9d.

Reverted https://github.com/pytorch/pytorch/pull/166367 on behalf of https://github.com/atalman due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/166367#issuecomment-3473150269))
2025-10-31 13:44:05 +00:00
4e8ba37ce3 Revert "[BE] Move GreenContext implementation details to cpp (#166462)"
This reverts commit 5d288bc3f73873887f681e15af83c5525e6a60bd.

Reverted https://github.com/pytorch/pytorch/pull/166462 on behalf of https://github.com/atalman due to Sorry, Reverting. Failure: test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_greencontext_carveout_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/18962393091/job/54154156892) [HUD commit link](85b035ca9c) ([comment](https://github.com/pytorch/pytorch/pull/166462#issuecomment-3473060299))
2025-10-31 13:20:48 +00:00
26534e9809 Revert "[GraphPartition] cache get_free_symbol_uses (#166338)"
This reverts commit a6b1ef17173f56ba93ac97ff4384fa4060b5e41e.

Reverted https://github.com/pytorch/pytorch/pull/166338 on behalf of https://github.com/atalman due to Failure: test/nn/test_convolution.py::TestConvolutionNN::test_conv3d_overflow_values [GH job link](https://github.com/pytorch/pytorch/actions/runs/18961173726/job/54149112920) [HUD commit link](a6b1ef1717) ([comment](https://github.com/pytorch/pytorch/pull/166338#issuecomment-3472980329))
2025-10-31 12:57:56 +00:00
657f8c3e21 Revert "Fix torch.full with dynamic tensor fill_value in torch.compile (#166554)"
This reverts commit 32066772b3dee643b1657b8957f32b5ac8b1390a.

Reverted https://github.com/pytorch/pytorch/pull/166554 on behalf of https://github.com/atalman due to Failure: test/nn/test_pooling.py::TestPoolingNNDeviceTypeCPU::test_max_pool_nan_inf_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18959368975/job/54144148546) [HUD commit link](32066772b3) ([comment](https://github.com/pytorch/pytorch/pull/166554#issuecomment-3472976911))
2025-10-31 12:55:31 +00:00
b0831930ed [inductor] Mark / restrict tests that only work if ATen is used for matmul (#166518)
These tests only work if max_autotune=False (default), which for matmul means falling back to ATen. This PR just documents / makes that transparent.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166518
Approved by: https://github.com/eellison
2025-10-31 12:29:06 +00:00
c01636e1bc Fixes the sparse tensor issue (#163535)
Fixes #148324

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163535
Approved by: https://github.com/janeyx99
2025-10-31 11:48:31 +00:00
fd68d409ad [xpu][feature] Integrate OneDNN SDPA training forward/backward into XPU OVERRIDEABLE Backend (#162454)
This is the second PR split from https://github.com/pytorch/pytorch/pull/156272

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162454
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/drisspg
2025-10-31 11:20:38 +00:00
0d3a4f7155 [CD] Enable Inductor performance test for xpu (#166289)
Add Dynamo benchmark performance tests for XPU backend

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166289
Approved by: https://github.com/EikanWang, https://github.com/atalman
2025-10-31 10:52:07 +00:00
108bb224f7 [pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.

Changes:

1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class. (A usage sketch follows this list.)
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.
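
A hedged usage sketch of the new helpers, assuming they live alongside the rest of the Python pytree API in `torch.utils._pytree` and accept child specs as shown (not taken from the PR's tests):
```python
import torch.utils._pytree as pytree

leaf = pytree.treespec_leaf()                      # replaces LeafSpec()
args_spec = pytree.treespec_tuple([leaf, leaf])    # spec for two positional args
kwargs_spec = pytree.treespec_dict({"x": leaf})    # spec for one keyword arg

print(args_spec.num_children, kwargs_spec.num_children)  # 2 1
print(leaf.is_leaf())                                     # True
```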

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
2025-10-31 10:33:16 +00:00
fc8ac1216c [4/N] Remove unused loop variables in tests (#166690)
This PR removes unused loop variables in tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166690
Approved by: https://github.com/justinchuby, https://github.com/mlazos
2025-10-31 10:20:48 +00:00
030de07aff [2/N] Use 'is' in callable comparisons (#166685)
It is generally advised to use `is/is not` for comparisons against torch functions.
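
A tiny illustration of the preferred style (sketch, not from the PR):
```python
import torch

def is_relu(fn) -> bool:
    return fn is torch.relu        # identity check, preferred over `fn == torch.relu`

print(is_relu(torch.relu))     # True
print(is_relu(torch.sigmoid))  # False
```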

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166685
Approved by: https://github.com/xmfan, https://github.com/mlazos
2025-10-31 08:08:07 +00:00
7d67a41db4 make FXConverter.generate use V.fake_mode instead of _detect_fake_mode_from_gm (#166591)
Summary:
FXConverter configures `_node_metadata_hook`, passing in `fake_mode` explicitly, which is relevant for cases down the line like `_generate_triton_call` that inserts a `triton_kernel_wrapper_mutation` node.

This `fake_mode` is obtained from `_detect_fake_mode_from_gm`, which can be different from inductor set `V.fake_mode`.

For example, while `V.fake_mode` is not None, `_detect_fake_mode_from_gm` can be **None** for a parent graph containing only a submodule that has no input args and contains only constants:
```
parent graph():
    %sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
    return (getitem,)

submodule graph():
    %randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cuda, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
    return (add,)

```

This discrepancy is problematic: it makes `_node_metadata_hook` try running inputs in a different fake_mode, or no fake_mode at all, while the rest of lowering uses `V.fake_mode`. In some cases where an input is placed on a custom non-GPU device, it can even complain with "requires device to be started" or a tensor device mismatch.

So this diff updates `FXConverter.generate` to use `V.fake_mode`, which is properly populated by inductor.

Test Plan:
Added a test `test_const_folded_subgraph` in `test_fxir_backend.py`. This test:
- creates a graph module that calls a subgraph with no inputs and containing only const-foldable ops
- const-folds the subgraph
- runs FXConverter.generate and expects the `fake_mode` used to code-generate to be non-None

With the prior implementation, which used `_detect_fake_mode_from_gm`, this test would fail as fake_mode would be `None`.

With this change, the test passes: `fake_mode` is properly collected from `V.fake_mode`, which is not None.

Differential Revision: D85767475

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166591
Approved by: https://github.com/blaine-rister, https://github.com/mlazos, https://github.com/eellison
2025-10-31 05:52:07 +00:00
85b035ca9c [nativert] Downcast triton double arguments to floats (#166620)
This diff tries to fix a limitation in the Sigmoid + Triton interaction, where float arguments are not correctly passed. NativeRT passes float arguments as double, while triton kernels were reading them as float, resulting in wrong values.

---

## Limitations in (de)serialization

In triton, float arguments to a kernel are encoded as "fp32" ([code](https://github.com/triton-lang/triton-cpu/blob/main-merged/python/triton/runtime/jit.py#L310-L326)):
```
        elif isinstance(arg, float):
            return ("fp32", None)
```
But it seems that torch export serde uses double ([code](d2eff5d454/torch/_export/serde/export_schema.thrift (L149))) because Thrift only has a double type:
```
union Argument {
  10: bool as_none;
  20: TensorArgument as_tensor;
  30: list<TensorArgument> as_tensors;
  50: i64 as_int;
  70: list<i64> as_ints;
  80: double as_float;   ===> actually double
...
```
The `TritonKernel` constructor loads attributes from a node, where `Constant` represents the variant type, and it only has `double` ([code](d2eff5d454/torch/nativert/graph/Graph.h (L86))):
```
using Constant = std::variant<
    None,
    int64_t,
    std::vector<int64_t>,
    double,    ===> triton float is loaded as double
```

So, NativeRT passes float arguments (originally floats in Triton) as double to triton kernels. But all of the triton backends (nvidia, amd, and cpu) read them as float because the signature still says `fp32`.

D84423898 was the workaround so far: wrapping float arguments in tensors.

## The Fix

Fixing the thrift definition isn't viable because Thrift only supports a double type. It's also possible to fix this on the triton side by downcasting from double to float, but that would require fixing all backends.

Instead, I think this diff is the most effective way: when building `TritonKernel`, downcast to float right after loading the double arguments.
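
A small, standalone illustration (not from the diff) of why reading a host-side double through an fp32 signature corrupts the value, assuming a little-endian layout where the kernel reads the first four bytes of the 8-byte slot:
```python
import struct

val = 0.1
as_double = struct.pack("<d", val)               # 8-byte IEEE-754 double, as NativeRT passes it
misread = struct.unpack("<f", as_double[:4])[0]  # kernel reads 4 bytes because the signature says fp32
print(val, misread)                              # 0.1 vs. a garbage value
```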

Test Plan:
```
buck test fbcode//mode/opt-amd-gpu fbcode//caffe2/test:test_export --
```

Differential Revision: D85747160

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166620
Approved by: https://github.com/XueningXu
2025-10-31 03:52:20 +00:00
267d0197bf [dynamo] fix error_on_graph_break bug where non-empty checkpoint results in unwanted graph break resumption (#166586)
Fixes https://github.com/pytorch/pytorch/issues/166589

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166586
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166476, #166477
2025-10-31 03:36:27 +00:00
1dec8a67a8 [dynamo, nested graph breaks] add disable_nested_graph_breaks decorator/context manager (#166477)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166477
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007
ghstack dependencies: #166476
2025-10-31 03:36:27 +00:00
797cd80b26 [dynamo, nested graph breaks] codegen dead nested cells correctly (#166476)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166476
Approved by: https://github.com/Lucaskabela
2025-10-31 03:36:27 +00:00
7d39401fa0 Revert "[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ (#166569)"
This reverts commit f1e4c42b6ef3d3cea08ab3babb693e3ce42cf08b.

Reverted https://github.com/pytorch/pytorch/pull/166569 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166569#issuecomment-3471180280))
2025-10-31 03:31:01 +00:00
e3ae0594d1 Add CUDA MXFP4 scaled mm support via. FBGEMM (#166526)
Summary:

* Pull in `f4f4bf16` from FBGemm to provide MXFP4 support for CUDA
* Add testing

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166526
Approved by: https://github.com/drisspg, https://github.com/ngimel
2025-10-31 03:17:27 +00:00
f1e4c42b6e [BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ (#166569)
Provides type coverage to ~3000 LOC and 200 methods in  `torch/_dynamo/variables/`

This is the first part of the final step to having 100% strict type coverage in dynamo - see previous comments in https://github.com/pytorch/pytorch/pull/166535 (combined into this one PR because ghstack was giving issues...)

### Coverage report:
```
mypy torch/_dynamo/variables --linecount-report /tmp/coverage_log
```
Compare before to after - we go from 3826 to 7221 lines covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166569
Approved by: https://github.com/williamwen42
2025-10-31 02:57:59 +00:00
d3e511f07c [Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)
**Summary:**
Support masked vectorization for the tail_loop for fp8 datatype.

**Example:**
```
import torch

def fn(
    x,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x,
        scale,
        zero_point,
        quant_min,
        quant_max,
        dtype,
    )
    x = torch.relu(x)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    return x

quant_min = -128
quant_max = 127
dtype = torch.float8_e4m3fn
x = torch.clamp(torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, quant_min, quant_max).to(dtype)
zero_point = 100
scale = 0.01

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x, scale, zero_point, quant_min, quant_max, dtype)
```

**Generated code:**

- Before
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(432L);x0_tail < static_cast<int64_t>(441L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = c10::convert<float>(tmp0);
                        auto tmp2 = static_cast<float>(100.0);
                        auto tmp3 = float(tmp1 - tmp2);
                        auto tmp4 = static_cast<float>(0.01);
                        auto tmp5 = float(tmp3 * tmp4);
                        auto tmp6 = c10::convert<float>(tmp5);
                        auto tmp7 = std::max(tmp6, decltype(tmp6)(0));
                        auto tmp8 = float(tmp7 * tmp2);
                        auto tmp9 = std::nearbyint(tmp8);
                        auto tmp10 = float(tmp9 + tmp2);
                        auto tmp11 = static_cast<float>(-128.0);
                        auto tmp12 = max_propagate_nan(tmp10, tmp11);
                        auto tmp13 = static_cast<float>(127.0);
                        auto tmp14 = min_propagate_nan(tmp12, tmp13);
                        auto tmp15 = c10::convert<at::Float8_e4m3fn>(tmp14);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp15;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163324
Approved by: https://github.com/Xia-Weiwen, https://github.com/mingfeima, https://github.com/jansel
2025-10-31 02:53:56 +00:00
d3be06cbdc [MTIAGraph][Pytorch][2/n] Add binding for Python to C++, and hook for Pytorch to Fbcode (#165963)
Summary:
This diff is the binding and hook layer for MTIA Graph, including:
1. the binding between Python and C++
2. the hook between PyTorch and MTIA fbcode

Architecture diagram: https://github.com/user-attachments/assets/31e24e5b-8324-42d8-8d3b-59536bc18340

[Doc](https://docs.google.com/document/d/1Q3xdZAIqhBvuy2HxGDfJyXVmxYXUEeYSZSwsp7bcJF8/edit?tab=t.osb46a42t6wb#heading=h.ayp9tkk08x00)

Test Plan: Will be tested in the python implementation which will use the binding and hook

Differential Revision: D84457757

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165963
Approved by: https://github.com/malfet, https://github.com/albanD
2025-10-31 02:52:51 +00:00
1129605415 [ROCm][CI] create ROCm 7.1 images for binary builds (#166665)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166665
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-31 02:52:37 +00:00
a6b1ef1717 [GraphPartition] cache get_free_symbol_uses (#166338)
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs.
ee7434be82/torch/_inductor/scheduler.py (L4869-L4885)

I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried the aten fallback for torchtitan, which results in 10k+ aten nodes. When processing the 600th node, it takes seconds to run `get_free_symbol_uses()` for a single node.

Why? Because `get_free_symbol_uses()` may recursively call another `get_free_symbol_uses()`, which could recursively run many times.
ee7434be82/torch/_inductor/ir.py (L4541-L4543)

This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed.
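
A minimal sketch of the memoization idea (illustrative; the PR caches on Inductor's IR nodes rather than with a standalone class):
```python
import sympy

class NodeSketch:
    def __init__(self, own_symbols, deps):
        self.own_symbols = set(own_symbols)   # symbols used directly by this node
        self.deps = deps                      # child nodes whose symbols are also needed
        self._cached = None                   # memo slot for get_free_symbol_uses()

    def get_free_symbol_uses(self):
        # Without the cache, this recursive walk re-runs for every query,
        # which blows up on graphs with 10k+ nodes.
        if self._cached is None:
            syms = set(self.own_symbols)
            for dep in self.deps:
                syms |= dep.get_free_symbol_uses()
            self._cached = syms
        return self._cached

s0 = sympy.Symbol("s0")
leaf = NodeSketch([s0], [])
root = NodeSketch([], [leaf, leaf])
print(root.get_free_symbol_uses())   # {s0}
```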

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166338
Approved by: https://github.com/eellison
2025-10-31 02:50:10 +00:00
12577064dd [MPS] Fix crash when max/min ops called for complex types (#166214)
Raise an exception, as the operation is meaningless for complex types and results in a segfault otherwise:
```
% python -c "import torch;torch.rand(10, dtype=torch.cfloat, device='mps').amax()"
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
/AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1347: failed assertion `original module failed verification'
zsh: abort      python -c
```

To be tested by `test_ops.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166214
Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #166272
2025-10-31 02:37:20 +00:00
24b6eb7727 [Inductor] Enable Custom op Autotune Decompositions and Parameter Tuning (#164212)
This PR introduces CustomOp autotuning. It allows users to provide a CustomOpConfig:
(1) to register (optional) multiple decomposition implementations for custom operations and
(2) to register parameter tuning knobs and the values they want to tune for the decompositions
so that Inductor automatically selects the best-performing variant through its autotune benchmarking.

Example:
```python
 register_custom_op_autotuning(
            custom_op=my_attention_op,
            configs=[
                CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
                CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
                CustomOpConfig(head_dim=128), # no decompositions
            ],
            input_gen_fns={
                "query": lambda fake: torch.randn_like(fake, device='cuda'),
                "key": lambda fake: torch.randn_like(fake, device='cuda'),
                "value": lambda fake: torch.randn_like(fake, device='cuda'),
            }
    )
```

**CustomOpConfig**: Each CustomOpConfig defines exactly one autotuning variant with specific parameter values and optional decomposition implementation with PyTorch aten ops. Users can register their own tuning knobs and optional decomposition functions for the same custom operation. The system automatically benchmarks all variants to select the best performing. If no decomposition is provided in the config, the CustomOp's default implementation will be used.

**Custom Input Generation**: Users can provide custom input generators via an optional `input_gen_fns` to control how synthetic inputs are created during benchmarking. This enables more realistic performance testing by generating inputs that match expected data distributions and characteristics for each tensor argument.

**More Examples with autotune logs:**:
1. Allow user to register customOp decompositions with tuning parameters for autotuning. Example usage:
```python
from torch._inductor.kernel.custom_op import CustomOpConfig, register_custom_op_autotuning

def decompose_k_implementation(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    """Matrix multiply with k-way decomposition."""
         # Implementation...with k_splits

@torch.library.custom_op("my_lib::decompose_k", mutates_args=())
def test_decompose_k_op(
        a: torch.Tensor, b: torch.Tensor, k_splits: int
    ) -> torch.Tensor:
        return decompose_k_implementation(a, b, k_splits)

# Register autotuning with different k_splits values
register_custom_op_autotuning(
    custom_op=test_decompose_k_op,
    configs=[
        CustomOpConfig(decompose_k_implementation, k_splits=2),
        CustomOpConfig(decompose_k_implementation, k_splits=32),
        CustomOpConfig(decompose_k_implementation, k_splits=64),
        CustomOpConfig(k_splits=128), # can make decomposition optional, then use default impl test_decompose_k_op
        CustomOpConfig(k_splits=256)
    ],
    input_gen_fns={
        "a": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
        "b": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
    }
)
```

Example result:
```
{"num_choices": 6, "num_triton_choices": 0, "best_kernel": "test_decompose_k_autotuned_fallback_default", "best_time": 0.09980800002813339}
AUTOTUNE test_decompose_k_autotuned(256x65536, 65536x1024)
strides: [65536, 1], [1024, 1]
dtypes: torch.float16, torch.float16
  test_decompose_k_autotuned_fallback_default 0.0998 ms 100.0%
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_2_0 0.1096 ms 91.0% CustomOp decompose_k_implementation_k_splits_2
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_32_1 0.1277 ms 78.2% CustomOp decompose_k_implementation_k_splits_32
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_64_2 0.1454 ms 68.6% CustomOp decompose_k_implementation_k_splits_64
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_128_3 0.1536 ms 65.0% CustomOp decompose_k_implementation_k_splits_128
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_256_4 0.2084 ms 47.9% CustomOp decompose_k_implementation_k_splits_256
```

2. Allow users to tune a parameter knob by passing the parameter and its values in the CustomOpConfig.
**Example**
```python
def mlp_variants(input_tensor, gate_weight, up_weight, down_weight, method):
    """MLP implementation with different computational approaches."""
    if method == 0:
        # Standard separate matmuls
        # ... implementation
    elif method == 1:
        # Batched approach with torch.mm
        # ... implementation
    elif method == 2:
        # Fused weights approach
        # ... implementation

@torch.library.custom_op("my_lib::mlp_op", mutates_args=())
        def mlp_op(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int,
        ) -> torch.Tensor:
            return mlp_variants(
                input_tensor, gate_weight, up_weight, down_weight, method=method
            )

register_custom_op_autotuning(
    custom_op=mlp_op,
    configs=[
        CustomOpConfig(method=0),
        CustomOpConfig(method=1),
        CustomOpConfig(method=2),
        # method=0 is the default fallback in the original op
    ],
    input_gen_fns={
        "input_tensor": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
        "gate_weight": lambda fake: torch.randn_like(fake, device='cuda') * 0.05,
        # ... other input generators
    }
)

```

Example result:
```
AUTOTUNE test_mlp_autotuned(4x32x512, 512x1024, 512x1024, 1024x256)
  test_mlp_autotuned_mlp_variants_method_2 0.0181 ms 100.0% CustomOp mlp_variants_method_2
  test_mlp_autotuned_mlp_variants_method_1 0.0185 ms 97.8% CustomOp mlp_variants_method_1
  test_mlp_autotuned_mlp_default_fallback_method_0 0.0198 ms 91.4% CustomOp fallback
```

### Test Suite (`test/inductor/test_custom_op_autotune.py`)

*   **RMSNorm autotuning**: Tests different RMSNorm implementations with dynamic input shapes
*   **MLP autotuning**: Tests different MLP decomposition and tuning "method" parameter
*   **DecomposeK**: Tests different k_splits values for matrix multiplication decomposition with k dim split
*   **Multi-parameter tuning**: Tests configs with multiple tuning parameters (scale_mode, chunk_size)

### Next Step:
- Enable Max-autotune with user passed in max-autotune config. https://github.com/pytorch/pytorch/pull/165526/files
- Support inline epilogue fusion for selected best customop decomposition with surrounding elementwise ops. https://github.com/pytorch/pytorch/pull/165952/files
- Support customop autotune considering fusion with multiTemplateBuffer. WIP

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164212
Approved by: https://github.com/zou3519
2025-10-31 02:28:00 +00:00
32066772b3 Fix torch.full with dynamic tensor fill_value in torch.compile (#166554)
Fixes #166253

## Summary
When `torch.full` is called with a 0-D tensor as `fill_value` inside a `torch.compile`'d function, the value was being incorrectly cached, causing subsequent calls with different values to return the first value.

## Root Cause
The Dynamo handler for `torch.full` was calling `aten._local_scalar_dense` to convert tensor fill_values to Python scalars at compile time, which baked the value into the compiled graph as a constant.

## Solution
Modified the Dynamo handler to decompose `torch.full(size, tensor_fill_value)` into `empty(size).fill_(tensor_fill_value)` when `fill_value` is a `TensorVariable`, keeping the fill value dynamic in the compiled graph.
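
A minimal repro sketch based on the issue description (before the fix, the second call returned the first call's value):
```python
import torch

@torch.compile
def make_full(fill):
    # fill is a 0-D tensor; its value must stay dynamic rather than being baked in
    return torch.full((3,), fill)

print(make_full(torch.tensor(1.0)))  # tensor([1., 1., 1.])
print(make_full(torch.tensor(2.0)))  # tensor([2., 2., 2.]) with the fix, not the cached 1.0
```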

## Testing
Added test case that verifies torch.full works correctly with dynamic tensor fill_values across multiple calls and dtypes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166554
Approved by: https://github.com/Lucaskabela
2025-10-31 00:56:02 +00:00
47f0024310 [CI][BE] Factor out repeated test code (#166481)
Into `_run_single_arg_fwd`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166481
Approved by: https://github.com/Skylion007
2025-10-31 00:52:50 +00:00
98d640bb11 Remove AT_USE_HIPSPARSE_GENERIC_API (#166393)
This macro is not used in OSS anymore.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166393
Approved by: https://github.com/ezyang
2025-10-31 00:49:09 +00:00
5d288bc3f7 [BE] Move GreenContext implementation details to cpp (#166462)
- Remove all complex defines logic from the header
- Make GreenContext constructor private, as it should only be created via the static method as a singleton
- Delete unused `getContext` and `getGreenContext` methods
- Rename `CUDA_HAS_GREEN_CONTEXT` to `HAS_CUDA_GREEN_CONTEXT()`, which results in a compilation error if one accidentally makes a typo
- Suppress `-Wunused-private-field` if GreenContext is not available
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166462
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-31 00:48:01 +00:00
bfb47ec50e [dynamo] support tracing new typing union syntax X | Y (#166599)
To do in a followup - I think there's an approach to reconstruct typing variables.
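
An illustrative sketch of the kind of runtime `X | Y` construction this lets dynamo trace (not taken from the PR's tests):
```python
import torch

def f(x, y=None):
    u = torch.Tensor | type(None)   # PEP 604 union built at runtime
    if isinstance(y, u):            # y is a Tensor or None
        return x * 2
    return x + y

compiled = torch.compile(f)
print(compiled(torch.ones(2)))      # tensor([2., 2.])
print(compiled(torch.ones(2), 3))   # tensor([4., 4.])
```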

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166599
Approved by: https://github.com/SherlockNoMad, https://github.com/anijain2305, https://github.com/Skylion007
2025-10-30 23:59:27 +00:00
7a0cd8ed09 [ROCm] Disable __builtin_amdgcn_rcpf for gfx90a (#166454)
Improves accuracy for some failing tests.

test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py::TestClipGradNormWorldSize4::test_clip_grad_norm_2d [GH job link](https://github.com/pytorch/pytorch/actions/runs/18930221123/job/54046876467) [HUD commit link](f20bf77874)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166454
Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily
2025-10-30 23:39:00 +00:00
984e64b2cd [inductor] Fix constant folder (#166655)
Fixes https://fb.workplace.com/groups/1028545332188949/permalink/1351999569843522/ where the resulting graph of the constant folder uses a sym node that is created later. Graph diff: https://www.internalfb.com/intern/diffing/?paste_number=2014609054

Before:
```
    %full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
    %select_18 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%full_65, 1, 0), kwargs = {})
    %mul_2792 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_18, 0), kwargs = {})
    %embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %mul_2792), kwargs = {})
```

After:
```
    %full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
    %full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_150], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
    %embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %full_default_1), kwargs = {})
    ...
    %sym_size_int_150 : [num_users=7] = call_function[target=torch.ops.aten.sym_size.int](args = (%view_193, 0), kwargs = {})
```

I couldn't figure out a small repro for this :/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166655
Approved by: https://github.com/eellison
2025-10-30 22:51:28 +00:00
b9bcb37f40 [DebugMode] store stringify args by default (#166347)
DebugMode currently stores dispatch call args & kwargs, which includes all intermediate tensors and more. This quickly OOMed on GPU when trying to debug some torchtitan / llama 8b models.

This PR defaults to storing the stringified version and adds a flag, `DebugMode(store_original_args=True)`, for users who want to store the original args as-is (and for BC).
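
A hedged usage sketch (the import path `torch.utils._debug_mode` and `debug_string()` are assumptions; only the `store_original_args` flag is taken from this PR):
```python
import torch
from torch.utils._debug_mode import DebugMode  # assumed import path

x = torch.randn(8)
with DebugMode() as dm:                          # default: args recorded as strings
    (x + 1).sum()
print(dm.debug_string())

with DebugMode(store_original_args=True) as dm:  # opt back in to keeping the real args
    (x + 1).sum()
```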

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166347
Approved by: https://github.com/yushangdi
2025-10-30 22:12:23 +00:00
7e3b9d105e [CP][BE][2/2] Refactor the code structure (#166501)
Our CP codebase now contains several files and we are adding more. This
PR refactors the code to consolidate the files into a context_parallel
folder but keep the import so that the existing users of CP won't be
affected.

Unfortunately, we have to split this PR into two PRs as the PyTorch
infra cannot accept a PR with 3000+ LoC change and git cannot recognize
that _context_parallel/_attention.py is moved from _attention.py because
we want to keep BC.

This is the second PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166501
Approved by: https://github.com/Skylion007
ghstack dependencies: #166456
2025-10-30 22:07:07 +00:00
45c3f02d69 [ROCm] moved gfx1100 back to experimental status for AOTriton (#166397)
According to the next commit to AOTriton:
8625c4faee

These changes were missed in the 0.11b release:
https://github.com/pytorch/pytorch/pull/161754

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166397
Approved by: https://github.com/jeffdaily
2025-10-30 21:43:01 +00:00
f5543e3741 [wip] fix searchsorted non dense (#165064)
Fix for https://github.com/pytorch/pytorch/issues/163528

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165064
Approved by: https://github.com/benjaminglass1, https://github.com/mlazos
2025-10-30 21:21:24 +00:00
5fc2c7a2a1 [ROCm][inductor] More configs for pointwise kernels. (#166470)
This config improves performance by 250% on some kernels that contain `t1.atomic_add(...)`. Again, we conditionalize for ROCm/HIP, so there is no impact to NV.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166470
Approved by: https://github.com/PaulZhang12, https://github.com/mlazos, https://github.com/eellison, https://github.com/jansel
2025-10-30 21:20:12 +00:00
7692fa09cd [Code Clean] Clean asserts in torch/ao/quantization/fx/* (#165420)
Replace assert statements with explicit if/raise patterns in:

- torch/ao/quantization/fx/* (177 errors)

Partially fixes #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165420
Approved by: https://github.com/RohitRathore1, https://github.com/fffrog, https://github.com/albanD
2025-10-30 20:53:36 +00:00
df71b70727 [cuDNN][conv] Re-enable cuDNN for 3D convolutions (fixed in 9.15+) (#166480)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166480
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-10-30 20:47:20 +00:00
80ba6e458f Add warning when users have incomplete setup for type checking (#166603)
Looking for feedback on this approach.
Received user reports of spurious pyrefly errors for users using hg instead of git. I think this was due to the fact that when using a venv and git, `make setup-env` installs requirements and pulls from a nightly torch wheel, which is needed for pyrefly to type check properly.

Initial documentation for `make setup-env` I found here: https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md#developing-pytorch

Testing:
```
hg clone --git ssh://git@github.com/pytorch/pytorch.git
conda create -n pytorch_env python=3.10 # (or manually create venv instead of using script)
cd pytorch
pip install -r requirements.txt
pip install -r requirements-build.txt
lintrunner init
# check how many pyrefly errors - 15,709 errors (11,693 ignored)
lintrunner # confirm error message / warning appears
>>> General linter failure:
  Warning (PYREFLY) nightly-wheel-not-run
    pytorch-nightly.pth not found. You may need to run make setup-env or make
    setup-env-conda to install nightly binaries and type stubs.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166603
Approved by: https://github.com/aorenste
2025-10-30 20:37:44 +00:00
0d50e5d8d4 [3/N] Fix unused loop variables (#166509)
This PR removes unused loop variables in tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166509
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007
2025-10-30 20:13:51 +00:00
99b05d1b78 Better 1x128, 128x128 error handling on non-Hopper (#166639)
Summary:

Blockwise 1x128 and 128x128 scaling is only available on CUDA >= 12.9
and only on Hopper GPUs. Attempting to run on B200 would give a
hard-to-debug `CUBLAS_STATUS_NOT_SUPPORTED`.

Add a more helpful `NotImplementedError` to catch this case.

Also more explicitly disable ROCm builds for relevant methods, based on
lack of support per [hipBLASlt
docs](https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/reference/datatypes.html#_CPPv4N28hipblasLtMatmulMatrixScale_t40HIPBLASLT_MATMUL_MATRIX_SCALE_VEC128_32FE).

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166639
Approved by: https://github.com/drisspg
2025-10-30 20:13:06 +00:00
f911d64750 [CUDA] xFail max-autotune grouped gemm tests on devices with insufficient SM count (#165921)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165921
Approved by: https://github.com/ngimel
2025-10-30 20:05:07 +00:00
52db60170d Enable verify_dynamo on Python 3.13 (#166497)
Dynamo now supports Python 3.13.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166497
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
2025-10-30 19:52:32 +00:00
56838bad5f [CP][BE][1/2] Refactor the code structure (#166456)
Our CP codebase now contains several files and we are adding more. This PR refactors the code to consolidate the files into a context_parallel folder, but keeps the existing imports so that current users of CP won't be affected.

Unfortunately, we have to split this change into two PRs, as the PyTorch infra cannot accept a PR with a 3000+ LoC change, and git cannot recognize that _context_parallel/_attention.py is moved from _attention.py because we keep the original file for BC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166456
Approved by: https://github.com/Skylion007
2025-10-30 19:46:49 +00:00
ad3a56ab98 Add a compile-time flag to trigger verbose logging for device-side asserts (#166171)
Summary:
Using `CUDA_KERNEL_ASSERT_PRINTF` inside kernels allows us to log invalid values to the console (which can in turn be used to surface _hopefully_ clearer error messages).

This does have an impact on the number of registers needed for the values being logged (I confirmed via diffing PTX that there is no other impact relative to using `__assert_fail`).

To avoid causing perf bottlenecks, this change adds a compile-time switch to enable more verbose errors in some of the common kernels that cause DSAs. There is also a Buck config that can be used to configure this switch more conveniently.

## Alternatives considered
I considered making the behavior of `CUDA_KERNEL_ASSERT_PRINTF` controllable via a compile-time macro instead of writing another wrapper for it but there are kernels where the extra register pressure is not as severe and in those cases, having more useful error messages by default is pretty useful.

Test Plan:
## Simple Python Driver:
```
# scatter_errors.py
import torch
def main() -> None:
    a = torch.rand(128, device="cuda:0")
    idx = torch.randint(0, 128, (100,), device="cuda:0")
    idx[0] = 9999
    b = torch.scatter(a, 0, idx, 555.0)
    print(b)
```

When running normally via:
```
$ buck2 run @//mode/opt  :scatter_errors
```
we see the following DSA message:
```
fbcode/caffe2/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:410: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
```

Running via:
```
$  buck2 run @//mode/opt -c fbcode.c10_enable_verbose_assert=1 :scatter_errors
```
however produces:
```
[CUDA_KERNEL_ASSERT] fbcode/caffe2/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:410: operator(): block: [0,0,0], thread: [0,0,0]: Assertion failed: `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"`: Expected 0 <= idx_dim < index_size (128), but got idx_dim = 9999
```

Differential Revision: D85185987

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166171
Approved by: https://github.com/ngimel
2025-10-30 19:43:46 +00:00
a7fd0b4001 [ROCm][CI] fix disk space message (#166645)
Fixes the disk space message so that it reports the machine does not have the required `difference = 100 - diskspace_cutoff_int` percent of space available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166645
Approved by: https://github.com/jeffdaily
2025-10-30 19:38:34 +00:00
181ee3bd42 fix: Add missing signals_to_handle to launcher logging (#166631)
Fixes #166630

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166631
Approved by: https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-10-30 19:31:25 +00:00
0ec0549823 Introduce a new API torch.xpu.get_per_process_memory_fraction (#165511)
# Motivation
Aligned with other backends, this PR introduces a new API, torch.xpu.get_per_process_memory_fraction, to allow users to retrieve the allowed memory fraction for a single process.
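
A minimal usage sketch, assuming an available XPU device and the companion setter `torch.xpu.set_per_process_memory_fraction` added earlier in this stack:

```
import torch

if torch.xpu.is_available():
    # Cap this process at 50% of the device's memory, then read the limit back.
    torch.xpu.set_per_process_memory_fraction(0.5)
    frac = torch.xpu.get_per_process_memory_fraction()
    print(f"allowed per-process memory fraction: {frac}")
```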

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165511
Approved by: https://github.com/EikanWang, https://github.com/ezyang
ghstack dependencies: #165508, #165509, #165510
2025-10-30 19:30:09 +00:00
8221ee6db9 [xpu] Fix type annotation for ProcessGroupXCCL (#166418)
After #163049, this PR fixes the type annotations to match the actual implementation for ProcessGroupXCCL::Options.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166418
Approved by: https://github.com/guangyey, https://github.com/ezyang
2025-10-30 19:29:06 +00:00
b939de26d1 Avoid writing temporary modules to disk (#157713)
In some cases the warning from #147744 still gets emitted because [atexit hooks aren't called](https://github.com/python/cpython/pull/114279).

Even in those cases, if the atexit hooks _were_ called you could end up with issues due to the directory being deleted in one process, but still being used elsewhere.

It's better all round to load these modules entirely in-memory.
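
A sketch of the general technique (executing generated source without touching disk); the function and module names here are illustrative, not the ones used by this PR:

```
import types

def load_module_from_source(name: str, source: str) -> types.ModuleType:
    # Build the module object directly and exec the source into its namespace,
    # so no temporary file (and no atexit cleanup) is needed.
    module = types.ModuleType(name)
    exec(compile(source, f"<in-memory {name}>", "exec"), module.__dict__)
    return module

mod = load_module_from_source("generated_mod", "def answer():\n    return 42\n")
print(mod.answer())  # 42
```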

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157713
Approved by: https://github.com/xush6528
2025-10-30 19:11:16 +00:00
694db5f549 Use 'is' in callable comparisons (#166624)
Just like we use `is/is not` for class comparisons, it is generally advised to use `is/is not` for comparisons against torch functions.
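
A small illustration of the preferred pattern:

```
import torch

def is_relu(fn) -> bool:
    # Identity check: no __eq__ overload is invoked, and the intent is explicit.
    return fn is torch.relu

print(is_relu(torch.relu))     # True
print(is_relu(torch.sigmoid))  # False
```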

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166624
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007
2025-10-30 19:00:09 +00:00
639a0b1239 Remove torch.distributed.tensor.OpSchema.has_symints (#163667)
It appears to be unused based on `cd torch; rg has_symints`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163667
Approved by: https://github.com/xmfan, https://github.com/azahed98, https://github.com/albanD
ghstack dependencies: #162990
2025-10-30 18:57:17 +00:00
398775a43e [CodeClean] Replace std::runtime_error with TORCH_CHECK (#165119)
As the title stated.

**Changes**:
- torch/csrc/inductor (Part 2)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165119
Approved by: https://github.com/janeyx99
ghstack dependencies: #165139
2025-10-30 18:43:58 +00:00
fcd5f8c352 [CodeClean] Remove the Unused MACRO for AOT Inductor Runtime (#165139)
As the title stated.

- AOTI_TORCH_CHECK depends on TORCH_CHECK_MSG, which is located in c10/util/Exception.h and may break BC
- AOTI_TORCH_CHECK is not used anywhere
- STD_TORCH_CHECK has ABI check tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165139
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
2025-10-30 18:43:58 +00:00
4acc66f119 Make PT2 compile backprop through custom op without autograd key a hard error (#166367)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166367
Approved by: https://github.com/bdhirsh
2025-10-30 18:43:07 +00:00
8f40a0c634 Revert "address DDE in matmul decomp (#166541)"
This reverts commit 90519402c2006237f891289a0afdec804515aa73.

Reverted https://github.com/pytorch/pytorch/pull/166541 on behalf of https://github.com/atalman due to breaks internal test ([comment](https://github.com/pytorch/pytorch/pull/166541#issuecomment-3469382334))
2025-10-30 18:11:33 +00:00
a5c3c08d10 [Pytorch] Use exp_u20 for aarch64's erf (#166594)
Summary:
After a precision study, we concluded it is OK to use ACL's exp function in f32's erf().
This way we can keep erf inline.

Benchmarks show about 91% higher throughput when processing a tensor of 1M elements, compiling with clang-19:

Before:
f32 erf: 2539.179us
After:
f32 erf: 1329.063us

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D85730452

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166594
Approved by: https://github.com/mcfi, https://github.com/fadara01
2025-10-30 18:09:05 +00:00
a553ea9ea4 Fix missing symbol when printing guards (#165723)
Fixes #165177

When converting guards to sources, if we are unable to get the expected symbol from symbol_to_source, we now try to get it from var_to_sources.
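
A conceptual sketch of that fallback (names simplified; the real logic lives in dynamo's guard-to-source conversion):

```
def source_for_symbol(sym, symbol_to_source, var_to_sources):
    # Prefer the direct mapping; fall back to var_to_sources when the symbol
    # was never registered in symbol_to_source.
    for mapping in (symbol_to_source, var_to_sources):
        sources = mapping.get(sym)
        if sources:
            return sources[0]
    raise KeyError(f"no source recorded for symbol {sym!r}")
```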

I was unable to make a simpler repro than what was described in the issue (which relies on llama3 - so inappropriate for a unit test).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165723
Approved by: https://github.com/bobrenjc93
2025-10-30 18:03:51 +00:00
ba71e9ca9a [DeviceMesh] Isolate pg creation logic in Device Mesh into a separate func _init_one_process_group (#166614)
To make the pg cache change easier and improve code modularization, we isolate the process group creation logic into a separate function named `_init_one_process_group`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166614
Approved by: https://github.com/lw
2025-10-30 17:57:41 +00:00
694d205143 Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit 311ea0dec0c50f395e6dac7b3875e81ee243fceb.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/atalman due to breaks internal builds Error: from logging_utils import ( ModuleNotFoundError: No module named 'logging_utils' ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3469308568))
2025-10-30 17:52:29 +00:00
629293f568 bucket all reduce (#166528)
Bucket all-reduce in the bucketer, building on @IvanKobzarev's earlier PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166528
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #166527
2025-10-30 17:12:34 +00:00
c37802a8c4 use multi-dtype bucketing (#166527)
Make the bucketer use multi-dtype bucketing for all-gathers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166527
Approved by: https://github.com/IvanKobzarev, https://github.com/ezyang
2025-10-30 16:54:49 +00:00
0a3ac47c0a Revert "[user-streams] Fix stream graph output semantics (#164819)"
This reverts commit f5cb9a4c68d9271c58ef4d3257210984b8e85099.

Reverted https://github.com/pytorch/pytorch/pull/164819 on behalf of https://github.com/atalman due to breaks CI ([comment](https://github.com/pytorch/pytorch/pull/164819#issuecomment-3469018283))
2025-10-30 16:53:32 +00:00
e83be7042e Fix pyrefly errors on main (#166548)
Fixes existing errors to keep noise from lintrunner to a minimum

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166548
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-10-30 16:47:27 +00:00
fb545fb068 Add MXFP4 grouped gemm support via. FBGEMM kernels (#166530)
Summary:

* Extend `_scaled_grouped_mm_v2` to include MXFP4 support
* Add testing to existing grouped routines

Test Plan:

```
pytest -svv -k "mxfp4 and group" test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166530
Approved by: https://github.com/drisspg
2025-10-30 16:46:11 +00:00
2df2c316e2 [devx] Fix invalid symbol definition emitted in fx_graph_runnable.py (#166529)
Summary: When emitting symbolic variable definition in fx_graph_runnable.py, we need to check if a SymNode is actually an expression, so that we won't generate something like "s27*s53**2 = 36".
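
An illustration of the check using sympy directly (the actual fix operates on SymNodes inside the fx_graph_runnable writer):

```
import sympy

def symbol_definition(expr: sympy.Expr, value: int) -> str | None:
    # Only a plain symbol may appear on the left-hand side of a definition.
    if isinstance(expr, sympy.Symbol):
        return f"{expr} = {value}"
    return None  # compound expressions like s27*s53**2 must not be "defined"

s27, s53 = sympy.symbols("s27 s53")
print(symbol_definition(s27, 13))           # s27 = 13
print(symbol_definition(s27 * s53**2, 36))  # None
```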

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166529
Approved by: https://github.com/mlazos
ghstack dependencies: #166432
2025-10-30 16:40:12 +00:00
08b0a8f11a [Inductor] Fix an inductor_provenance bug (#166432)
Summary: Fix an inductor_provenance related error seen when running TORCH_COMPILE_DEBUG generated fx_graph_runnable.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166432
Approved by: https://github.com/mlazos
2025-10-30 16:40:12 +00:00
3f1824742c Revert "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. (#166277)"
This reverts commit b2a0f90501dd3a16a6ccaf4c49e1c10f6df4ce1d.

Reverted https://github.com/pytorch/pytorch/pull/166277 on behalf of https://github.com/atalman due to Breaks internal executorch tests ([comment](https://github.com/pytorch/pytorch/pull/166277#issuecomment-3468696623))
2025-10-30 15:49:23 +00:00
bbb7d2270b [inductor] print 0.0 as 0 for triton (#164291)
Fixes https://github.com/pytorch/pytorch/issues/164157
Fixes https://github.com/pytorch/pytorch/issues/164086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164291
Approved by: https://github.com/bobrenjc93, https://github.com/mlazos
2025-10-30 15:15:25 +00:00
6a5a436624 DTensor: C++ compute_global_tensor_info (#162990)
compute_global_tensor_info is on the hot path for DTensor.{from,to}_local. More incremental progress toward C++.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162990
Approved by: https://github.com/ezyang
2025-10-30 15:10:54 +00:00
ad559072db [triton][sigmoid] Fix kernel cache and serialization issue for triton sigmoid + CUDA kernel bug (#166568)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166568
Approved by: https://github.com/minjang
2025-10-30 14:54:54 +00:00
ad02bd13df Revert "[user-streams] Add current stream source (#165211)"
This reverts commit 79aee77381b21d41c77148e5ff84c4b351aaf144.

Reverted https://github.com/pytorch/pytorch/pull/165211 on behalf of https://github.com/atalman due to failure: test/test_python_dispatch.py::TestPythonDispatch::test_return_stream [GH job link](https://github.com/pytorch/pytorch/actions/runs/18942517662/job/54086481693) [HUD commit link](7563f61cc8) ([comment](https://github.com/pytorch/pytorch/pull/165211#issuecomment-3468332362))
2025-10-30 14:34:43 +00:00
7563f61cc8 Make bucketing aware of collective LIFO semantics (#166324)
In the initial pr for overlapping preserving bucketing, for a graph like:

```
def foo(...):
     ag = all_gather(...)
     hiding_compute = mm(...)
     wait(ag)
```

We would add dependencies from mm -> ag and from wait -> hiding_compute, to prevent bucketing from reordering these collectives in a way that destroys the overlap. However, there is an additional way for bucketing to prevent overlap.

If we were to reorder another collective so the graph looked like:

```
def foo(...):
     ag = all_gather(...)
     ar = all_reduce(...)
     wait(ar)
     hiding_compute = mm(...)
     wait(ag)
```

Overlap would not occur, because the wait for the all reduce would also force realization of every collective enqueued on the same stream prior to the all reduce. NCCL uses a single stream per process group.

To model this, we initially set a strict ordering of all collective starts, waits, and hiding compute when bucketing. Then, when trying to add a collective to a bucket, we check whether we would interfere with overlap under all of the following possible bucketings:

[move collective start to bucket start, move bucket start to collective start] x [move collective wait to bucket wait, move bucket wait to collective wait].

For any of these positions, we check if overlap would have been interfered with because of stream queue semantics. Then, if not, we remove the moving start and wait from the constrained ordering of collectives, and see if it's topologically valid to merge the nodes.
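
A toy model of that stream-queue check, purely illustrative (the real pass works on graph nodes and the constrained collective ordering described above): waiting on any collective drains every collective enqueued earlier on the same stream, so a hiding compute only overlaps a collective if no such wait runs before the compute.

```
def overlap_preserved(events, coll, compute):
    """events: program-ordered ("start", name) / ("wait", name) / ("compute", name) tuples."""
    enqueue_order = [name for kind, name in events if kind == "start"]
    in_flight = False
    for kind, name in events:
        if kind == "start" and name == coll:
            in_flight = True
        elif kind == "compute" and name == compute:
            return in_flight  # overlap only if coll was started and not yet drained
        elif kind == "wait" and in_flight:
            # Waiting on anything enqueued at or after coll also drains coll.
            if enqueue_order.index(name) >= enqueue_order.index(coll):
                return False
    return False

good = [("start", "ag"), ("compute", "mm"), ("wait", "ag")]
bad = [("start", "ag"), ("start", "ar"), ("wait", "ar"), ("compute", "mm"), ("wait", "ag")]
print(overlap_preserved(good, "ag", "mm"))  # True: mm hides ag
print(overlap_preserved(bad, "ag", "mm"))   # False: wait(ar) drains ag before mm runs
```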

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166324
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #166309
2025-10-30 13:37:00 +00:00
fa8e073a4e Revert "[triton][sigmoid] Fix kernel cache and serialization issue for triton sigmoid + CUDA kernel bug (#166568)"
This reverts commit d46d8d6f54b15ded4f2483c7bde31be124281ab8.

Reverted https://github.com/pytorch/pytorch/pull/166568 on behalf of https://github.com/atalman due to Failed test/test_extension_utils.py::TestExtensionUtils::test_external_module_register_with_renamed_backend [GH job link](https://github.com/pytorch/pytorch/actions/runs/18931754443/job/54050880312) [HUD commit link](d46d8d6f54) ([comment](https://github.com/pytorch/pytorch/pull/166568#issuecomment-3468008894))
2025-10-30 13:31:47 +00:00
95b5534773 Revert "[user-streams] Track symbolic current stream (#165212)"
This reverts commit a5335263d32b5be2b2647661334d81225c3cc3fc.

Reverted https://github.com/pytorch/pytorch/pull/165212 on behalf of https://github.com/atalman due to test/test_rename_privateuse1_to_existing_device.py::TestRenamePrivateuseoneToExistingBackend::test_external_module_register_with_existing_backend [GH job link](https://github.com/pytorch/pytorch/actions/runs/18930365446/job/54046768884) [HUD commit link](a5335263d3) ([comment](https://github.com/pytorch/pytorch/pull/165212#issuecomment-3467968796))
2025-10-30 13:24:56 +00:00
9ee1afbf66 Revert "[user-streams] Handle returning the current stream with/without device index (#165356)"
This reverts commit f1af679270392c83e03808c8af5e2cbe3cdf16ce.

Reverted https://github.com/pytorch/pytorch/pull/165356 on behalf of https://github.com/atalman due to test/test_rename_privateuse1_to_existing_device.py::TestRenamePrivateuseoneToExistingBackend::test_external_module_register_with_existing_backend [GH job link](https://github.com/pytorch/pytorch/actions/runs/18930365446/job/54046768884) [HUD commit link](a5335263d3) ([comment](https://github.com/pytorch/pytorch/pull/165356#issuecomment-3467967061))
2025-10-30 13:22:24 +00:00
f60751024e Revert "[2/N] Add strict parameter to Python zip calls (#166257)"
This reverts commit 39e5cdddf7e57881c52473d1288a66f0222527e1.

Reverted https://github.com/pytorch/pytorch/pull/166257 on behalf of https://github.com/atalman due to Failing: test/distributed/fsdp/test_fsdp_mixed_precision.py::TestFSDPTrainEval::test_train_ema_eval_flow [GH job link](https://github.com/pytorch/pytorch/actions/runs/18934047991/job/54057218160) [HUD commit link](39e5cdddf7) ([comment](https://github.com/pytorch/pytorch/pull/166257#issuecomment-3467955332))
2025-10-30 13:20:00 +00:00
2de4cf2102 [1/N] Remove unused loop variables (#166258)
This PR removes unused loop variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-10-30 12:22:25 +00:00
369f2d6951 [3/N] fix typo in other folders (#166606)
fix typo in other folders

#166374
#166126

_typos.toml
```bash
[files]
extend-exclude = ["tools/linter/dictionary.txt"]
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
Sur = "Sur"
nin = "nin"
tme = "tme"
inpt = "inpt"
mis = "mis"
Raison = "Raison"
ouput = "ouput"
nto = "nto"
Onwer = "Onwer"
callibrate = "callibrate"
ser = "ser"
Metdata = "Metdata"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166606
Approved by: https://github.com/ezyang
2025-10-30 10:30:40 +00:00
32920926f0 [xpu][fix] [Inductor] Avoid using tl.sqrt_rn on XPU before triton is ready (#165740)
Fixes #165738

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165740
Approved by: https://github.com/etaf, https://github.com/EikanWang, https://github.com/chuanqi129, https://github.com/desertfire
2025-10-30 09:24:24 +00:00
39e5cdddf7 [2/N] Add strict parameter to Python zip calls (#166257)
This PR adds `strict=True/False` to zip calls in test utils. strict=True is passed when possible.
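
A small example of what `strict=True` buys in tests: a silently truncated comparison becomes a loud failure.

```
expected = [1, 2, 3]
actual = [1, 2]

print(list(zip(expected, actual)))  # [(1, 1), (2, 2)] - the mismatch goes unnoticed
try:
    list(zip(expected, actual, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1
```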

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
2025-10-30 08:10:10 +00:00
2829d48bd1 [xpu][test][1/N] Port 3 fsdp distributed test cases to Intel GPU (#161476)
For https://github.com/pytorch/pytorch/issues/114850, we will port 3 distributed tests to Intel GPU.
We enable Intel GPU with the following methods, while trying our best to keep the original code style:

- use "torch.accelerator.current_accelerator()" to determine the accelerator backend (see the sketch after this list)
- use "requires_accelerator_dist_backend" to enable "xccl"
- enable XPU for some test paths
- skip some test cases that Intel GPU does not support
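
A sketch of the backend-agnostic device selection referenced in the first bullet, assuming a torch build that ships the `torch.accelerator` API (the backend mapping here is illustrative):

```
import torch

acc = torch.accelerator.current_accelerator()  # e.g. device("cuda"), device("xpu"), or None
device_type = acc.type if acc is not None else "cpu"
backend = {"cuda": "nccl", "xpu": "xccl"}.get(device_type, "gloo")
print(device_type, backend)
```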

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161476
Approved by: https://github.com/weifengpy, https://github.com/guangyey
2025-10-30 07:30:04 +00:00
f1af679270 [user-streams] Handle returning the current stream with/without device index (#165356)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165356
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304, #164522, #164819, #165211, #165212
2025-10-30 07:20:25 +00:00
d46d8d6f54 [triton][sigmoid] Fix kernel cache and serialization issue for triton sigmoid + CUDA kernel bug (#166568)
Differential Revision: D85792537

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166568
Approved by: https://github.com/minjang
2025-10-30 06:17:39 +00:00
a5335263d3 [user-streams] Track symbolic current stream (#165212)
merge into stream tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165212
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304, #164522, #164819, #165211
2025-10-30 04:58:53 +00:00
79aee77381 [user-streams] Add current stream source (#165211)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165211
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304, #164522, #164819
2025-10-30 04:58:53 +00:00
f5cb9a4c68 [user-streams] Fix stream graph output semantics (#164819)
Previously, we would stash a single stream value we constructed at trace time in a global and return the same value from repeated calls to the graph.

With this PR, we construct the stream value in advance, reference the constructed value in the graph via the lookup table, and if that value is returned as an output, read the value from the lookup table and return it (in bytecode, not as a graph output, since we don't support arbitrary stream outputs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304, #164522
2025-10-30 04:58:46 +00:00
f20bf77874 [audio hash update] update the pinned audio hash (#166597)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166597
Approved by: https://github.com/pytorchbot
2025-10-30 04:28:30 +00:00
75f798e05b [inductor][mi350] add tech specs for MI350 (#166576)
Summary:
While digging through matmul padding for other work, I noticed that the compute-bound check won't work on MI350 since we haven't supplied the tech specs yet.

I added the MI350 specs following the predefined format.

Test Plan: CI

Differential Revision: D85804980

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166576
Approved by: https://github.com/leitian
2025-10-30 03:46:52 +00:00
476b149a00 bwd pass (#164504)
**Summary**
This implements the backward pass for the Varlen API and registers `_varlen_attn()` as a custom op.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

|        | Variable Length API | SDPA     |
|--------|--------------------|----------|
| Runtime | 0.8189142608642578 ms       | 3.263883056640625 ms  |
| TFLOPs | 268.652       | 158.731  |

We can see that the Varlen runtime is >3x faster than SDPA.
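
For reference, a sketch of how the variable-length setup above can be generated (illustrative only): sequence lengths are drawn as random multiples of 64 up to `max_seq_len`, with cumulative offsets in the usual varlen layout.

```
import torch

batch_size, max_seq_len = 8, 2048
seq_lens = torch.randint(1, max_seq_len // 64 + 1, (batch_size,)) * 64
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0))  # [0, l0, l0+l1, ...]
print(seq_lens.tolist(), cu_seqlens.tolist())
```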

**Testing**

Run `python test/test_varlen_attention.py` for unit tests, where we verify basic functionality and confirm that varlen gradients numerically match SDPA.

For custom op testing, `test_custom_op_registration` uses logging mode to verify that `_varlen_attn()` was called and tests with `torch.compile`. `test_custom_op_compliances` uses `torch.library.opcheck()` to verify.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164504
Approved by: https://github.com/drisspg
2025-10-30 03:46:37 +00:00
845da9c817 [ONNX] Ignore pyrefly errors in torchlib (#166588)
Fixes #166475

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166588
Approved by: https://github.com/titaiwangms
2025-10-30 03:43:52 +00:00
0918bf321c [xpu][test] Reuse native_mm and mix_order_reduction for Intel GPU. (#166384)
This PR reuses native_mm and mix_order_reduction for Intel GPU and enables the corresponding test.
Fixes #165370

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166384
Approved by: https://github.com/jansel
2025-10-30 03:38:35 +00:00
90519402c2 address DDE in matmul decomp (#166541)
Address https://github.com/pytorch/pytorch/issues/165081
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166541
Approved by: https://github.com/mlazos
2025-10-30 03:19:29 +00:00
791ca80d3a Enable local tensor mode for DTensor attention and convolution tests (#166406)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166406
Approved by: https://github.com/ezyang
2025-10-30 02:48:02 +00:00
5cbdade914 Fix a syntactic error in test_indexing.py (#166390)
This PR fixes a syntax error in test_indexing.py caused by a misplaced `if`/`else` expression.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166390
Approved by: https://github.com/jerryzh168
2025-10-30 02:28:01 +00:00
0187db88d4 [ROCm][CI] Create periodic-rocm-mi200.yml (#166544)
* We are separating out the rocm jobs of the periodic workflow
* We are introducing a new label `ciflow/periodic-rocm-mi200` to allow us to run distributed tests only on ROCm runners, without triggering many other jobs on the `periodic.yml` workflow (via `ciflow/periodic`)
* This new workflow will also be triggered via the `ciflow/periodic`, thus maintaining the old status quo.
* We are reverting to the `linux.rocm.gpu.4` label since it targets a lot more CI nodes at this point than the K8s/ARC-based `linux.rocm.gpu.mi250.4` label, as that is still having some network/scaling issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166544
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-30 02:08:07 +00:00
311ea0dec0 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-30 01:50:54 +00:00
cf7756da38 Bump uv from 0.9.5 to 0.9.6 in /.ci/lumen_cli (#166578)
Bumps [uv](https://github.com/astral-sh/uv) from 0.9.5 to 0.9.6.
- [Release notes](https://github.com/astral-sh/uv/releases)
- [Changelog](https://github.com/astral-sh/uv/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/uv/compare/0.9.5...0.9.6)

---
updated-dependencies:
- dependency-name: uv
  dependency-version: 0.9.6
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-29 18:28:14 -07:00
e380028a51 [inductor][choices] lookup table choices 1/3 (#164978)
# why

- enable users to control which choices get used on which inputs
- reduce lowering time, and pin kernel selection, by selecting
  them for the inputs

# what

- a new InductorChoices subclass that implements a lookup table
- a README explaining the usage
- corresponding testing

- currently only supports templates that go through
  `V.choices.get_template_configs`

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -v
```

Differential Revision: [D85685743](https://our.internmc.facebook.com/intern/diff/D85685743)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164978
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/mlazos
2025-10-30 01:28:01 +00:00
b4403bfc62 Add waitcounters for torch.compile subprocess pool (#164527)
Summary:
This adds a waitcounter for whether or not the pool is running, as well as whether
we are running jobs.

This also adds waitcounters for the first job within a pool. The first-job and running counters are working correctly. The job waitcounter seems to either be detecting a job leak, or is subtly broken.

Test Plan:
We've tested this internally and see valid ods metrics.

Note that we may be leaking jobs, or the job logic may not be handling an exception correctly.

Differential Revision: D83705931

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164527
Approved by: https://github.com/masnesral
2025-10-30 01:15:26 +00:00
12c12466b0 [ROCm][CI] remove amdgpu from install_rocm.sh (#166575)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166575
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-30 01:08:33 +00:00
f4d05feb7a Repro dynamo issue for union typed annotation (#166443)
When a nested function has a type annotation using "|", it fails.

It works fine with `Union[torch.Tensor, DTensor]`, though.
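
A minimal illustration of the pattern (a hypothetical repro, not the exact one from the PR):

```
import torch
from torch.distributed.tensor import DTensor

def outer(x: torch.Tensor):
    def inner(t: torch.Tensor | DTensor):  # PEP 604 "|" annotation on a nested function
        return t * 2
    return inner(x)

compiled = torch.compile(outer)
# compiled(torch.randn(4))  # per the report, the "|" form trips dynamo; Union[...] works
```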

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166443
Approved by: https://github.com/anijain2305
2025-10-30 01:05:15 +00:00
7481622237 [symbolic shapes] remove maybe_guard_rel warning (#166553)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166553
Approved by: https://github.com/laithsakka
2025-10-30 00:57:28 +00:00
b2a0f90501 Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. (#166277)
Fix https://github.com/pytorch/pytorch/issues/163894

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166277
Approved by: https://github.com/Lucaskabela
2025-10-30 00:34:05 +00:00
14d4a77495 disable current modes instead of no dispatch in estimation (#166571)
Otherwise, the custom estimation's TorchDispatchModes would be disabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166571
Approved by: https://github.com/SherlockNoMad, https://github.com/bdhirsh
2025-10-29 23:24:41 +00:00
3d4ca228be Remove METADATA.bzl files (#166574)
(meta-internal, should not be synced)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166574
Approved by: https://github.com/bigfootjon
2025-10-29 23:17:41 +00:00
447 changed files with 13213 additions and 6056 deletions

View File

@ -195,13 +195,16 @@ case "$tag" in
NINJA_VERSION=1.9.0
TRITON=yes
;;
pytorch-linux-jammy-xpu-n-py3)
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
VISION=yes
XPU_VERSION=2025.2
NINJA_VERSION=1.9.0
TRITON=yes
if [[ $tag =~ "benchmarks" ]]; then
INDUCTOR_BENCHMARKS=yes
fi
;;
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10

View File

@ -3,7 +3,7 @@
set -eux
ACL_VERSION=${ACL_VERSION:-"v25.02"}
ACL_VERSION=${ACL_VERSION:-"v52.6.0"}
ACL_INSTALL_DIR="/acl"
# Clone ACL

View File

@ -40,11 +40,7 @@ EOF
# Default url values
rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}"
amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu"
# Add amdgpu repository
UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'`
echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list
# Add rocm repository
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -

View File

@ -12,8 +12,8 @@ function do_install() {
rocm_version_nodot=${rocm_version//./}
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"
rocm_dir="/opt/rocm"

View File

@ -54,12 +54,15 @@ ENV OPENSSL_DIR /opt/openssl
RUN rm install_openssl.sh
ARG INDUCTOR_BENCHMARKS
ARG ANACONDA_PYTHON_VERSION
ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION
COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
# Install XPU Dependencies
ARG XPU_VERSION

View File

@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.9.5"
"uv==0.9.6"
]
[tool.setuptools]

View File

@ -1,7 +1,7 @@
SHELL=/usr/bin/env bash
DOCKER_CMD ?= docker
DESIRED_ROCM ?= 7.0
DESIRED_ROCM ?= 7.1
DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM))
PACKAGE_NAME = magma-rocm
# inherit this from underlying docker image, do not pass this env var to docker
@ -16,6 +16,7 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
magma-rocm/build_magma.sh
.PHONY: all
all: magma-rocm71
all: magma-rocm70
all: magma-rocm64
@ -24,6 +25,11 @@ clean:
$(RM) -r magma-*
$(RM) -r output
.PHONY: magma-rocm71
magma-rocm71: DESIRED_ROCM := 7.1
magma-rocm71:
$(DOCKER_RUN)
.PHONY: magma-rocm70
magma-rocm70: DESIRED_ROCM := 7.0
magma-rocm70:

View File

@ -6,8 +6,8 @@ set -eou pipefail
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
# Folders for the build
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata
@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE
# Fetch magma sources and verify checksum
pushd ${PACKAGE_DIR}
git clone https://github.com/jeffdaily/magma
git clone https://github.com/icl-utk-edu/magma
pushd magma
git checkout ${MAGMA_VERSION}
popd

View File

@ -426,7 +426,7 @@ fi
if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
# export test times so that potential sharded tests that'll branch off this build will use consistent data
# don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
python tools/stats/export_test_times.py
PYTHONPATH=. python tools/stats/export_test_times.py
fi
# don't do this for bazel or s390x or riscv64 as they don't use sccache
if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then

View File

@ -572,6 +572,8 @@ fi
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
DYNAMO_BENCHMARK_FLAGS+=(--device cpu)
elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
DYNAMO_BENCHMARK_FLAGS+=(--device xpu)
else
DYNAMO_BENCHMARK_FLAGS+=(--device cuda)
fi
@ -665,6 +667,8 @@ test_perf_for_dashboard() {
device=cuda_b200
elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
device=rocm
elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
device=xpu
fi
for mode in "${modes[@]}"; do
@ -1757,7 +1761,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
else
# Do this after checkout_install_torchbench to ensure we clobber any
# nightlies that torchbench may pull in
if [[ "${TEST_CONFIG}" != *cpu* ]]; then
if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* ]]; then
install_torchrec_and_fbgemm
fi
PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"

View File

@ -27,7 +27,9 @@ runs:
docker system prune -af
diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then
echo "Error: Available diskspace is less than $diskspace_cutoff percent. Not enough diskspace."
diskspace_cutoff_int=$((diskspace_cutoff + 0))
difference=$((100 - diskspace_cutoff_int))
echo "Error: Available diskspace is less than $difference percent. Not enough diskspace."
echo "$msg"
exit 1
else

View File

@ -1 +1 @@
69bbe7363897764f9e758d851cd0340147d27f94
3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2

View File

@ -19,6 +19,7 @@ ciflow_push_tags:
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-perf-test-nightly-xpu
- ciflow/inductor-periodic
- ciflow/inductor-rocm
- ciflow/linux-aarch64
@ -26,6 +27,7 @@ ciflow_push_tags:
- ciflow/nightly
- ciflow/op-benchmark
- ciflow/periodic
- ciflow/periodic-rocm-mi200
- ciflow/periodic-rocm-mi300
- ciflow/pull
- ciflow/quantization-periodic

View File

@ -11,11 +11,17 @@ architectures:
* Latest XPU
"""
import json
import os
import re
from pathlib import Path
from typing import Optional
# NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this
SCRIPT_DIR = Path(__file__).absolute().parent
REPO_ROOT = SCRIPT_DIR.parent.parent
CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"]
CUDA_STABLE = "12.8"
CUDA_ARCHES_FULL_VERSION = {
@ -31,8 +37,7 @@ CUDA_ARCHES_CUDNN_VERSION = {
"13.0": "9",
}
# NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this
ROCM_ARCHES = ["6.4", "7.0"]
ROCM_ARCHES = ["7.0", "7.1"]
XPU_ARCHES = ["xpu"]
@ -137,9 +142,48 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
}
def get_nccl_wheel_version(arch_version: str) -> str:
import re
# Used by tools/nightly.py
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
NIGHTLY_SOURCE_MATRIX = {
"cpu": dict(
name="cpu",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
supported_platforms=["Linux", "macOS", "Windows"],
accelerator="cpu",
)
}
CUDA_NIGHTLY_SOURCE_MATRIX = {
f"cuda-{major}.{minor}": dict(
name=f"cuda-{major}.{minor}",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu{major}{minor}",
supported_platforms=["Linux", "Windows"],
accelerator="cuda",
)
for major, minor in (map(int, version.split(".")) for version in CUDA_ARCHES)
}
ROCM_NIGHTLY_SOURCE_MATRIX = {
f"rocm-{major}.{minor}": dict(
name=f"rocm-{major}.{minor}",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm{major}.{minor}",
supported_platforms=["Linux"],
accelerator="rocm",
)
for major, minor in (map(int, version.split(".")) for version in ROCM_ARCHES)
}
XPU_NIGHTLY_SOURCE_MATRIX = {
"xpu": dict(
name="xpu",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/xpu",
supported_platforms=["Linux"],
accelerator="xpu",
)
}
NIGHTLY_SOURCE_MATRIX.update(CUDA_NIGHTLY_SOURCE_MATRIX)
NIGHTLY_SOURCE_MATRIX.update(ROCM_NIGHTLY_SOURCE_MATRIX)
NIGHTLY_SOURCE_MATRIX.update(XPU_NIGHTLY_SOURCE_MATRIX)
def get_nccl_wheel_version(arch_version: str) -> str:
requirements = map(
str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version])
)
@ -147,17 +191,14 @@ def get_nccl_wheel_version(arch_version: str) -> str:
def read_nccl_pin(arch_version: str) -> str:
from pathlib import Path
nccl_pin_path = os.path.join(
Path(__file__).absolute().parents[2],
".ci",
"docker",
"ci_commit_pins",
f"nccl-cu{arch_version[:2]}.txt",
nccl_pin_path = (
REPO_ROOT
/ ".ci"
/ "docker"
/ "ci_commit_pins"
/ f"nccl-cu{arch_version[:2]}.txt"
)
with open(nccl_pin_path) as f:
return f.read().strip()
return nccl_pin_path.read_text().strip()
def validate_nccl_dep_consistency(arch_version: str) -> None:
@ -165,7 +206,8 @@ def validate_nccl_dep_consistency(arch_version: str) -> None:
wheel_ver = get_nccl_wheel_version(arch_version)
if not nccl_release_tag.startswith(f"v{wheel_ver}"):
raise RuntimeError(
f"{arch_version} NCCL release tag version {nccl_release_tag} does not correspond to wheel version {wheel_ver}"
f"{arch_version} NCCL release tag version {nccl_release_tag} "
f"does not correspond to wheel version {wheel_ver}"
)
@ -412,7 +454,14 @@ def generate_wheels_matrix(
return ret
validate_nccl_dep_consistency("13.0")
validate_nccl_dep_consistency("12.9")
validate_nccl_dep_consistency("12.8")
validate_nccl_dep_consistency("12.6")
arch_version = ""
for arch_version in CUDA_ARCHES:
validate_nccl_dep_consistency(arch_version)
del arch_version
if __name__ == "__main__":
# Used by tools/nightly.py
(SCRIPT_DIR / "nightly_source_matrix.json").write_text(
json.dumps(NIGHTLY_SOURCE_MATRIX, indent=4) + "\n"
)

View File

@ -38,6 +38,10 @@ on:
default: ""
description: |
List of tests to include (empty string implies default list)
dashboard-tag:
required: false
type: string
default: ""
disable-monitor:
description: |
[Experimental] Disable utilization monitoring for tests.
@ -58,6 +62,11 @@ on:
required: false
type: number
default: 1
secrets:
HUGGING_FACE_HUB_TOKEN:
required: false
description: |
HF Auth token to avoid rate limits when downloading models or datasets from hub
permissions:
id-token: write
contents: read
@ -196,6 +205,8 @@ jobs:
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }}
DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
run: |
# Fetch aws credential from IMDs
@ -246,6 +257,8 @@ jobs:
-e PYTORCH_TEST_RERUN_DISABLED_TESTS \
-e TESTS_TO_INCLUDE \
-e ZE_AFFINITY_MASK \
-e HUGGING_FACE_HUB_TOKEN \
-e DASHBOARD_TAG \
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
--ulimit stack=10485760:83886080 \
--ulimit core=0 \

View File

@ -36,7 +36,7 @@ jobs:
runs-on: linux.9xlarge.ephemeral
strategy:
matrix:
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"]
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm7.0", "rocm7.1", "cpu"]
steps:
- name: Build docker image
uses: pytorch/pytorch/.github/actions/binary-docker-build@main

View File

@ -52,8 +52,8 @@ jobs:
{ tag: "cuda12.9" },
{ tag: "cuda12.8" },
{ tag: "cuda12.6" },
{ tag: "rocm6.4" },
{ tag: "rocm7.0" },
{ tag: "rocm7.1" },
{ tag: "cpu" },
]
steps:

View File

@ -34,7 +34,7 @@ jobs:
id-token: write
strategy:
matrix:
rocm_version: ["70", "64"]
rocm_version: ["71", "70"]
steps:
- name: Checkout PyTorch
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

View File

@ -54,8 +54,8 @@ jobs:
{ name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm7.1", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28_aarch64-builder", tag: "cpu-aarch64", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "xpu", runner: "linux.9xlarge.ephemeral" },

View File

@ -55,7 +55,7 @@ jobs:
docker-image: ["pytorch/manylinux2_28-builder:cpu"]
include:
- device: "rocm"
rocm_version: "7.0"
rocm_version: "7.1"
runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
- device: "cuda"
rocm_version: ""

View File

@ -67,6 +67,7 @@ jobs:
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-jammy-xpu-n-py3,
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
pytorch-linux-jammy-py3-clang18-asan,
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,

View File

@ -384,124 +384,6 @@ jobs:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm6_4-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
timeout-minutes: 300
build_name: libtorch-rocm6_4-shared-with-deps-release
build_environment: linux-binary-libtorch
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-rocm6_4-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-rocm6_4-shared-with-deps-release-build
- get-label-type
runs-on: linux.rocm.gpu.mi250
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-rocm6_4-shared-with-deps-release
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
docker-image-name: libtorch-cxx11-builder
custom-tag-prefix: rocm6.4
docker-build-dir: .ci/docker
working-directory: pytorch
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
env:
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm6_4-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-rocm6_4-shared-with-deps-release-test
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
build_name: libtorch-rocm6_4-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm7_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
@ -619,3 +501,121 @@ jobs:
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm7_1-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
timeout-minutes: 300
build_name: libtorch-rocm7_1-shared-with-deps-release
build_environment: linux-binary-libtorch
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-rocm7_1-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-rocm7_1-shared-with-deps-release-build
- get-label-type
runs-on: linux.rocm.gpu.mi250
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-rocm7_1-shared-with-deps-release
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
docker-image-name: libtorch-cxx11-builder
custom-tag-prefix: rocm7.1
docker-build-dir: .ci/docker
working-directory: pytorch
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
env:
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm7_1-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-rocm7_1-shared-with-deps-release-test
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
build_name: libtorch-rocm7_1-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml

File diff suppressed because it is too large

View File

@ -0,0 +1,148 @@
name: inductor-perf-nightly-xpu
on:
push:
tags:
- ciflow/inductor-perf-test-nightly-xpu/*
schedule:
- cron: 30 17 * * *
workflow_dispatch:
inputs:
training:
description: Run training (on by default)?
required: false
type: boolean
default: true
inference:
description: Run inference (on by default)?
required: false
type: boolean
default: true
default:
description: Run inductor_default?
required: false
type: boolean
default: false
dynamic:
description: Run inductor_dynamic_shapes?
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
type: boolean
default: false
freezing_cudagraphs:
description: Run inductor_cudagraphs with freezing for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
type: boolean
default: false
maxautotune:
description: Run inductor_max_autotune?
required: false
type: boolean
default: false
benchmark_configs:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf,inductor_timm_perf,inductor_torchbench_perf,cachebench
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
xpu-n-py3_10-inductor-benchmark-build:
name: xpu-n-py3.10-inductor-benchmark
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_xpu", shard: 1, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 2, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 3, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 4, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 5, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
]}
secrets: inherit
xpu-n-py3_10-inductor-benchmark-test-nightly:
permissions:
id-token: write
contents: read
if: github.event_name != 'workflow_dispatch'
name: xpu-n-py3.10-inductor-benchmark
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
timeout-minutes: 720
# Disable monitor in perf tests pending further investigation
disable-monitor: true
monitor-log-interval: 10
monitor-data-collect-interval: 2
secrets: inherit
xpu-n-py3_10-inductor-benchmark-test:
permissions:
id-token: write
contents: read
if: github.event_name == 'workflow_dispatch'
name: xpu-n-py3.10-inductor-test
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit


@ -0,0 +1,84 @@
name: periodic-rocm-mi200
on:
schedule:
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
# Also run less frequently on weekends.
- cron: 45 0,8,16 * * 1-5
- cron: 45 4 * * 0,6
- cron: 45 4,12,20 * * 1-5
- cron: 45 12 * * 0,6
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
push:
tags:
- ciflow/periodic/*
- ciflow/periodic-rocm-mi200/*
branches:
- release/*
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
llm-td:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read
target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit


@ -204,37 +204,6 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build:
name: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck
uses: ./.github/workflows/_linux-build.yml


@ -6,6 +6,7 @@ on:
- pull
- trunk
- periodic
- periodic-rocm-mi200
- periodic-rocm-mi300
- inductor
- unstable

.gitignore

@ -143,6 +143,7 @@ scripts/release_notes/*.json
sccache-stats*.json
lint.json
merge_record.json
.github/scripts/nightly_source_matrix.json
# These files get copied over on invoking setup.py
torchgen/packaged/*


@ -374,7 +374,7 @@ cmake_dependent_option(
"Build the lazy Torchscript backend, not compatible with mobile builds" ON
"NOT INTERN_BUILD_MOBILE" OFF)
cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin folder"
OFF "USE_CUDA" OFF)
cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
"CPU_AARCH64" OFF)


@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")


@ -19,6 +19,13 @@ inline namespace CPU_CAPABILITY {
#error "Big endian is not supported."
#endif
// GCC does not properly optimize bf16 operators
#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19)
#define BF16_ARITHMETIC_SUPPORTED() 1
#else
#define BF16_ARITHMETIC_SUPPORTED() 0
#endif
// Unlike the float16_t family of types, bfloat16_t is not available
// when we're not targeting bfloat16 hardware support on some
// platforms (but not Mac, so we have to be careful not to shadow the
@ -352,18 +359,72 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
other, &Vectorized<float>::name); \
}
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
Vectorized frac() const;
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
#ifdef __ARM_FEATURE_BF16
// Flip sign bit
Vectorized<c10::BFloat16> neg() const {
return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768));
}
// Fast reciprocal is fine because we are truncating results
Vectorized<c10::BFloat16> reciprocal() const {
auto x = vcvtq_low_f32_bf16(values);
auto y = vcvtq_high_f32_bf16(values);
x = vrecpeq_f32(x);
y = vrecpeq_f32(y);
return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y);
}
// Clearing the sign bit
Vectorized<c10::BFloat16> abs() const {
return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF);
}
#else
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
#endif
// These functions are optimized on clang-21+
#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21)
Vectorized<c10::BFloat16> operator==(
const Vectorized<c10::BFloat16>& other) const {
return values == other.values;
}
Vectorized<c10::BFloat16> operator!=(
const Vectorized<c10::BFloat16>& other) const {
return values != other.values;
}
Vectorized<c10::BFloat16> operator<(
const Vectorized<c10::BFloat16>& other) const {
return values < other.values;
}
Vectorized<c10::BFloat16> operator<=(
const Vectorized<c10::BFloat16>& other) const {
return values <= other.values;
}
Vectorized<c10::BFloat16> operator>(
const Vectorized<c10::BFloat16>& other) const {
return values > other.values;
}
Vectorized<c10::BFloat16> operator>=(
const Vectorized<c10::BFloat16>& other) const {
return values >= other.values;
}
#else
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
@ -412,28 +473,52 @@ template <>
Vectorized<c10::BFloat16> inline operator+(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x + y;
#else
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator-(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x - y;
#else
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator*(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x * y;
#else
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator/(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x / y;
#else
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
#endif
}
// frac. Implement this here so we can use subtraction
@ -544,12 +629,19 @@ Vectorized<c10::BFloat16> inline fmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y + z;
#else
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
// elements, not the bottom and top half, so they don't seem
// particularly useful here. Ideally we would include dot product in
// the Vectorized interface...
return a * b + c;
#endif
}
template <>
@ -557,8 +649,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y + z;
#else
// See NOTE [BF16 FMA] above.
return -a * b + c;
#endif
}
template <>
@ -566,8 +665,15 @@ Vectorized<c10::BFloat16> inline fmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y - z;
#else
// See NOTE [BF16 FMA] above.
return a * b - c;
#endif
}
template <>
@ -575,8 +681,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y - z;
#else
// See NOTE [BF16 FMA] above.
return -a * b - c;
#endif
}
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
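For reference, the DEFINE_*_VIA_FLOAT macros and binary_operator_via_float above fall back to widening each bf16 lane to float, applying the float operator, and narrowing back. A minimal standalone scalar sketch of that round-trip (illustrative only, not the NEON code above; NaN handling in the narrowing step is simplified):

#include <cstdint>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 float32.
static float bf16_to_float(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Narrow with round-to-nearest-even.
static uint16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + rounding) >> 16);
}

// The "via float" fallback: widen both operands, apply the float op, narrow back.
template <typename Op>
uint16_t bf16_binary_via_float(uint16_t a, uint16_t b, Op op) {
  return float_to_bf16(op(bf16_to_float(a), bf16_to_float(b)));
}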


@ -6,9 +6,9 @@ namespace at::vec {
inline namespace CPU_CAPABILITY {
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
// Enable auto-vectorization for GCC-13+ and clang-17+
// Enable auto-vectorization for clang-17+
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
#if defined(__clang__) && (__clang_major__ >= 17)
template <typename from_type, typename to_type>
inline void convertImpl(


@ -309,7 +309,7 @@ class Vectorized<float> {
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
// Implementation copied from Arm Optimized Routine
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
Vectorized<float> exp_u20() const {
inline Vectorized<float> vexpq_f32_u20() const {
// bail out to sleef if it's a special case:
// i.e. there's an input s.t. |input| > 87.3....
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
@ -348,6 +348,9 @@ class Vectorized<float> {
return vfmaq_f32(scale, poly, scale);
}
Vectorized<float> exp_u20() const {
return vexpq_f32_u20();
}
Vectorized<float> fexp_u20() const {
return exp_u20();
}
@ -634,7 +637,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
// - exp(- x * x)
auto pow_2 = (*this) * (*this);
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
auto tmp4 = neg_pow_2.exp();
auto tmp4 = neg_pow_2.vexpq_f32_u20();
auto tmp5 = tmp4 ^ neg_zero_vec;
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
auto tmp6 = t * tmp5;


@ -1,78 +1,90 @@
#include <ATen/cuda/CUDAGreenContext.h>
namespace at::cuda {
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <stdexcept>
#include <vector>
#define HAS_CUDA_GREEN_CONTEXT() 1
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#define HAS_CUDA_GREEN_CONTEXT() 0
// Suppress unused private field warnings as this class is not supposed to be used
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field")
#endif
namespace at::cuda {
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if HAS_CUDA_GREEN_CONTEXT()
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
std::unique_ptr<GreenContext> GreenContext::create(
uint32_t num_sms,
std::optional<uint32_t> device_id) {
#if CUDA_HAS_GREEN_CONTEXT
#if HAS_CUDA_GREEN_CONTEXT()
if (!device_id.has_value()) {
device_id = at::cuda::current_device();
}
return std::make_unique<GreenContext>(device_id.value(), num_sms);
return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
@ -80,7 +92,7 @@ namespace at::cuda {
// Implement move operations
GreenContext::GreenContext(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
#if HAS_CUDA_GREEN_CONTEXT()
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
@ -91,7 +103,7 @@ namespace at::cuda {
}
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
#if HAS_CUDA_GREEN_CONTEXT()
if (this != &other) {
// Clean up current resources
if (green_ctx_) {
@ -120,7 +132,7 @@ namespace at::cuda {
}
GreenContext::~GreenContext() noexcept{
#if CUDA_HAS_GREEN_CONTEXT
#if HAS_CUDA_GREEN_CONTEXT()
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
@ -128,25 +140,9 @@ namespace at::cuda {
#endif
}
// Get the underlying CUDA context
CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
return green_ctx_;
}
#endif
// Make this context current
void GreenContext::setContext() {
#if CUDA_HAS_GREEN_CONTEXT
#if HAS_CUDA_GREEN_CONTEXT()
auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream();
@ -175,7 +171,7 @@ namespace at::cuda {
}
void GreenContext::popContext() {
#if CUDA_HAS_GREEN_CONTEXT
#if HAS_CUDA_GREEN_CONTEXT()
// see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream());
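A hedged usage sketch, based only on the interface exercised above (create, setContext, popContext) and a hypothetical SM count of 16:

#include <ATen/cuda/CUDAGreenContext.h>
#include <optional>

void run_on_sm_partition() {
  // Reserve a green context limited to 16 SMs on the current device
  // (16 is an arbitrary example; requires CUDA 12.8+ per the driver check above).
  auto ctx = at::cuda::GreenContext::create(/*num_sms=*/16, /*device_id=*/std::nullopt);
  ctx->setContext();   // make the green context current
  // ... enqueue kernels on the current stream; they run on the reserved SM partition ...
  ctx->popContext();   // restore the previous context
}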


@ -1,53 +1,38 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif
// Forward declare green context as opaque ptr
typedef struct CUgreenCtx_st* CUgreenCtx;
namespace at::cuda {
class TORCH_CUDA_CPP_API GreenContext {
public:
GreenContext(uint32_t device_id, uint32_t num_sms);
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
// Green context creation
static std::unique_ptr<GreenContext> create(
uint32_t num_sms,
std::optional<uint32_t> device_id);
~GreenContext() noexcept;
// Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete;
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
~GreenContext() noexcept;
// Get the underlying CUDA context
CUcontext getContext() const;
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx getGreenContext() const;
#endif
// Make this context current
void setContext();
void popContext();
private:
#if CUDA_HAS_GREEN_CONTEXT
GreenContext(uint32_t device_id, uint32_t num_sms);
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
int32_t device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
#endif
};
} // namespace at::cuda
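One consequence of the now-private constructor, visible in the .cpp diff above, is that create() switches from std::make_unique to a plain new: make_unique cannot reach a private constructor, while a static member factory can. A generic sketch of the pattern (hypothetical Widget class, not PyTorch code):

#include <memory>

class Widget {
 public:
  static std::unique_ptr<Widget> create(int n) {
    // OK: a member function may call the private constructor.
    return std::unique_ptr<Widget>(new Widget(n));
    // return std::make_unique<Widget>(n);  // would not compile: the ctor is private
  }

 private:
  explicit Widget(int n) : n_(n) {}
  int n_;
};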


@ -7,17 +7,6 @@
#endif
#if defined(USE_ROCM)
// hipSparse const API added in v2.4.0
#if HIPSPARSE_VERSION >= 200400
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#else
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#endif
#else // USE_ROCM
#define AT_USE_HIPSPARSE_GENERIC_API() 0
#endif // USE_ROCM
// cuSparse Generic API spsv function was added in CUDA 11.3.0
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1


@ -1,5 +1,6 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
@ -151,6 +152,36 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
}
virtual bool isAvailable() const override;
/* MTIAGraph related APIs */
virtual int64_t mtiagraphCreate(bool keep_graph = false) const {
FAIL_MTIAHOOKS_FUNC(__func__);
return -1;
}
virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphCaptureEnd(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphInstantiate(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphReplay(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphReset(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual MempoolId_t mtiagraphPool(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
};
struct TORCH_API MTIAHooksArgs {};


@ -410,8 +410,8 @@ struct ConvParams {
return false;
}
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
// broken on cuDNN 9.8
if (cudnn_version >= 90800) {
// broken on cuDNN 9.8 - 9.14
if (cudnn_version >= 90800 && cudnn_version < 91500) {
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
(input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) &&
weight.dim() == 5) {


@ -170,10 +170,14 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
// Conditions for bias to be fusable
&& (
self.is_contiguous() &&
// NOTE: fine to have 1-len dims to the left from the right-most one
(self.dim() == 1 || self.squeeze().dim() == 1) &&
self.sizes().back() == mat2_sizes[1]
)
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
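The relaxed bias condition above accepts any bias whose only non-singleton dimension is the trailing one matching mat2's column count. A small libtorch sketch with hypothetical shapes illustrates which biases now pass the check:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto bias_1d = torch::randn({8});        // dim() == 1            -> fusable
  auto bias_3d = torch::randn({1, 1, 8});  // squeeze().dim() == 1  -> fusable
  auto bias_2d = torch::randn({4, 8});     // squeeze().dim() == 2  -> not fusable
  std::cout << bias_1d.dim() << " "
            << bias_3d.squeeze().dim() << " "
            << bias_2d.squeeze().dim() << std::endl;  // prints: 1 1 2
}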


@ -213,9 +213,9 @@ _f4_f4_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& global_scale_a,
const std::optional<Tensor>& global_scale_a,
const Tensor& scale_b,
const Tensor& global_scale_b,
const std::optional<Tensor>& global_scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
Tensor& out) {
@ -225,14 +225,28 @@ _f4_f4_bf16_grouped_mm_fbgemm(
"mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
"global_scale_a must be Float, got: ", global_scale_a.scalar_type());
TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
"global_scale_b must be Float, got: ", global_scale_b.scalar_type());
std::optional<Tensor> combined_global_scale = std::nullopt;
if (global_scale_a.has_value() || global_scale_b.has_value()) {
// NVFP4
TORCH_CHECK_VALUE(global_scale_a.has_value() && global_scale_b.has_value(),
"For NVFP4 grouped gemm both of global_scale_{a,b} must have values")
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
TORCH_CHECK_VALUE(global_scale_a.value().scalar_type() == at::kFloat,
"global_scale_a must be Float, got: ", global_scale_a.value().scalar_type());
TORCH_CHECK_VALUE(global_scale_b.value().scalar_type() == at::kFloat,
"global_scale_b must be Float, got: ", global_scale_b.value().scalar_type());
combined_global_scale = global_scale_a.value().mul(global_scale_b.value());
} else {
// MXFP4
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu,
"scale_a must be Float8_e8m0fnu, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e8m0fnu,
"scale_b must be Float8_e8m0fnu, got: ", scale_b.scalar_type());
}
auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
mat_a,
@ -241,7 +255,7 @@ _f4_f4_bf16_grouped_mm_fbgemm(
scale_b,
offs.value(),
out,
global_scale_a.mul(global_scale_b)
combined_global_scale
);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA")
@ -471,9 +485,10 @@ namespace {
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 4> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "mxfp4_mxfp4", scaled_blas::check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4},
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};
} // anonymous namespace
@ -599,6 +614,21 @@ _scaled_grouped_mm_cuda_v2(
offs.value(),
out);
}
case ScaledGemmImplementation::MXFP4_MXFP4: {
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _f4_f4_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0], /* block-scale A */
std::nullopt, /* global-scale A */
scale_b[0], /* block-scale B */
std::nullopt, /* global-scale B */
offs.value(),
std::nullopt, /* bias */
out);
}
case ScaledGemmImplementation::NVFP4_NVFP4: {
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
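The dispatch rule introduced above reduces to: both global scales present selects the NVFP4 recipe, both absent selects MXFP4, and a mixed state is rejected. A standalone sketch of that rule (hypothetical helper, plain floats standing in for tensors):

#include <optional>
#include <stdexcept>

enum class Fp4Recipe { NVFP4, MXFP4 };

Fp4Recipe classify(const std::optional<float>& global_scale_a,
                   const std::optional<float>& global_scale_b) {
  if (global_scale_a.has_value() != global_scale_b.has_value()) {
    throw std::invalid_argument("both global scales must be provided, or neither");
  }
  return global_scale_a.has_value() ? Fp4Recipe::NVFP4 : Fp4Recipe::MXFP4;
}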


@ -13,7 +13,7 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
if (allow_neg_indices) {
ind = (ind < 0) ? ind + ind_dim_size : ind;
}
CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind);
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
if (off >= slice_size) return;
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);


@ -59,6 +59,22 @@
// forward declare
class cublasCommonArgs;
namespace fbgemm_gpu {
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// Updating supported ops means a submodule bump, which is... painful. Instead, we
// can simply forward-declare the methods we want to use. This works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);
} // namespace fbgemm_gpu
using at::blas::ScalingType;
using at::blas::SwizzleType;
@ -794,6 +810,24 @@ void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const Sc
}
}
void
_check_deepseek_support() {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major != 9) {
// Only on Hopper GPUs
TORCH_CHECK_NOT_IMPLEMENTED(
dprops->major == 9,
"DeepSeek style (1x128, 128x128) scaling only supported in CUDA for SM90")
}
// Only in cublasLt >= 12.9
TORCH_CHECK_NOT_IMPLEMENTED(
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
);
#endif
}
Tensor&
_scaled_block1x128_block1x128(
const Tensor& mat_a, const Tensor& mat_b,
@ -802,8 +836,12 @@ _scaled_block1x128_block1x128(
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
@ -821,6 +859,12 @@ _scaled_block1x128_block1x128(
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"1x128 and 128x128 scaling not available with ROCm"
);
#endif
}
Tensor&
@ -831,10 +875,12 @@ _scaled_block128x128_block1x128(
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
// CUDA: Only Hopper GPUs
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
@ -852,6 +898,12 @@ _scaled_block128x128_block1x128(
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"1x128 and 128x128 scaling not available with ROCm"
);
#endif
}
Tensor&
@ -862,8 +914,12 @@ _scaled_block1x128_block128x128(
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
@ -881,6 +937,12 @@ _scaled_block1x128_block128x128(
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"1x128 and 128x128 scaling not available with ROCm"
);
#endif
}
Tensor&
@ -951,26 +1013,47 @@ _scaled_mxfp4_mxfp4(
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#if !defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI)
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#endif
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
auto K_multiplier = 2;
#ifdef USE_ROCM
// AMD
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
#else
// NVIDIA
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
#endif
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifdef USE_ROCM
// AMD
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
#else
// NVIDIA
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
#ifdef USE_ROCM
// AMD
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;
@ -985,11 +1068,29 @@ _scaled_mxfp4_mxfp4(
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
#else
// NVIDIA
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
// but we have one we need to use. Two clear options are to copy into
// our output (slow), or use a move-assignment-operator (faster).
// However, the compiler can complain about the explicit move preventing
// copy elision because the return from f4f4bf16 is a temporary object.
// So we don't explicitly move, and trust the compiler here...
// In the longer term this should be fixed on the FBGemm side.
out = fbgemm_gpu::f4f4bf16(
mat_a,
mat_b.transpose(-2, -1),
scale_a,
scale_b,
std::nullopt, /* global_scale */
true /* use_mx */
);
return out;
#endif
}
Tensor&
@ -1114,17 +1215,20 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}
// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
mat_a.sizes()[1] % 16 == 0,
K_multiplier * mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
mat_a.sizes()[1],
K_multiplier * mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");
// TODO(slayton): Existing checks, not sure if they should really be here.
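To make the packed-FP4 bookkeeping above concrete: each Float4_e2m1fn_x2 element holds two 4-bit values, so the effective K is twice the reported trailing dimension. A worked example with hypothetical sizes, following the NVIDIA branch of the scale-count formula:

#include <cstdint>
#include <iostream>

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }
constexpr int64_t round_up(int64_t a, int64_t b) { return ceil_div(a, b) * b; }

int main() {
  const int64_t M = 64, K_packed = 128;  // hypothetical mat_a shape: M x K_packed
  const int64_t K_multiplier = 2;        // two fp4 values per packed element
  // NVIDIA branch: one e8m0 scale per 32 effective-K elements, padded for the
  // swizzled layout (rows to a multiple of 128, scale columns to a multiple of 4).
  const int64_t scale_a_elems =
      round_up(M, 128) * round_up(ceil_div(K_multiplier * K_packed, 32), 4);
  std::cout << "expected scale_a elements: " << scale_a_elems << std::endl;  // 128 * 8 = 1024
}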


@ -160,8 +160,8 @@ struct _cuda_scatter_gather_internal_kernel {
auto offsets = offset_calc.get(i);
int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
&& "scatter gather kernel index out of bounds");
CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size
&& "scatter gather kernel index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim);
f(
(scalar_t*)(self_ptr + offsets[0]),
@ -406,9 +406,8 @@ struct _cuda_scatter_fill_internal_kernel {
auto offsets = offset_calc.get(i);
int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
&& "index out of bounds"
);
CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size
&& "index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim);
f(
(scalar_t*)(self_ptr + offsets[0]),


@ -141,7 +141,8 @@ WelfordDataLN cuWelfordOnlineSum(
if constexpr (!rms_norm){
U delta = val - curr_sum.mean;
U new_count = curr_sum.count + 1.f;
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf`
#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
#else
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
@ -163,7 +164,8 @@ WelfordDataLN cuWelfordCombine(
U count = dataA.count + dataB.count;
U mean, sigma2;
if (count > decltype(dataB.count){0}) {
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf`
#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
auto coef = __builtin_amdgcn_rcpf(count);
#else
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
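For context, the update being tuned here is the standard Welford online step; the diff only changes how the reciprocal of the new count is computed on some ROCm targets. A scalar sketch of the underlying algorithm:

#include <cstdio>
#include <initializer_list>

struct Welford {
  double mean = 0.0, m2 = 0.0, count = 0.0;
  void push(double val) {
    count += 1.0;
    const double delta = val - mean;
    mean += delta / count;       // the step computed above with 1.f/new_count or a fast reciprocal
    m2 += delta * (val - mean);  // running sum of squared deviations
  }
  double variance() const { return count > 0 ? m2 / count : 0.0; }
};

int main() {
  Welford w;
  for (double v : {1.0, 2.0, 3.0, 4.0}) w.push(v);
  std::printf("mean=%f var=%f\n", w.mean, w.variance());  // mean=2.500000 var=1.250000
}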


@ -40,14 +40,37 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
return true;
}
bool check_no_grad(sdp::sdp_params const& params, bool debug) {
const bool any_inputs_require_grad = params.query.requires_grad() ||
params.key.requires_grad() || params.value.requires_grad();
const bool gradmode_enabled = at::GradMode::is_enabled();
if (debug && any_inputs_require_grad && gradmode_enabled) {
TORCH_WARN("Backward or grad to be supported.");
bool input_require_grad(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_mask) {
return at::GradMode::is_enabled() &&
(query.requires_grad() || key.requires_grad() || value.requires_grad() ||
(attn_mask.has_value() && attn_mask.value().requires_grad()));
}
bool check_grad(sdp::sdp_params const& params, bool debug) {
if (!input_require_grad(
params.query, params.key, params.value, params.attn_mask))
return true;
auto q_num_heads = params.query.sym_size(-3);
auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(-3);
bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
if (debug && is_gqa)
TORCH_WARN(
"scale_dot_product_attention with gqa is not supported for gradient computation on xpu.");
bool attn_mask_needs_grad =
params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
if (debug && attn_mask_needs_grad) {
TORCH_WARN(
"scale_dot_product_attention on xpu is not supported when attn_mask.requires_grad() == True.");
}
return !any_inputs_require_grad || !gradmode_enabled;
return !is_gqa && !attn_mask_needs_grad;
}
bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
@ -65,7 +88,7 @@ bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
sdp::check_nonzero_sequence_lengths_dense,
sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
check_head_dim_size_xpu,
check_no_grad);
check_grad);
for (auto& constraint : constraints) {
if (!constraint(params, debug)) {
return false;
@ -225,10 +248,11 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
std::optional<double> scale,
bool compute_logsumexp) {
TORCH_INTERNAL_ASSERT(
query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
"scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
"scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {B, H, T, K}");
TORCH_INTERNAL_ASSERT(
(key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
(key.size(2) == value.size(2)),
@ -245,6 +269,9 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
TORCH_INTERNAL_ASSERT(
!(attn_bias.has_value() && is_causal),
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
TORCH_INTERNAL_ASSERT(
!(attn_bias.has_value() && attn_bias.value().requires_grad()),
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot have requires_grad=True");
const int64_t batch_size = query.size(0);
const int64_t num_head_q = query.size(1);
@ -254,11 +281,14 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
const int64_t seq_len_q = query.size(2);
const int64_t seq_len_kv = key.size(2);
at::Tensor output;
std::vector<int64_t> output_shape = {
at::Tensor attention;
std::vector<int64_t> attention_shape = {
batch_size, num_head_q, seq_len_q, head_dim_v};
alloc_with_matching_layout(query, output, output_shape);
at::Tensor logsumexp, debug_attn_mask; // not supported
alloc_with_matching_layout(query, attention, attention_shape);
auto opts = query.options();
at::Tensor logsumexp =
at::empty({batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat));
at::native::onednn::sdpa(
batch_size,
@ -274,15 +304,15 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
attn_bias,
is_causal,
scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)),
output,
false,
attention,
compute_logsumexp,
logsumexp);
// rng not used
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));
return std::make_tuple(
output,
attention,
logsumexp,
/* cum_seq_q */ at::Tensor(),
/* cum_seq_k */ at::Tensor(),
@ -290,7 +320,106 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
seq_len_kv,
philox_seed,
philox_offset,
debug_attn_mask);
/*debug_attn_mask */ at::Tensor());
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward_xpu(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& attn_bias,
std::array<bool, 4> grad_input_mask,
const at::Tensor& out,
const at::Tensor& logsumexp,
const at::Tensor& cum_seq_q,
const at::Tensor& cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
std::optional<double> scale) {
TORCH_INTERNAL_ASSERT(
grad_out.dim() == 4 && out.dim() == 4 &&
grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {B, H, T, K}");
TORCH_INTERNAL_ASSERT(
query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Accept only 4 dims inputs shape of {B, H, T, K}");
TORCH_INTERNAL_ASSERT(
(key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
(key.size(2) == value.size(2)),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: K/V should have the same batch / seq / num_head");
TORCH_INTERNAL_ASSERT(
query.size(0) == grad_out.size(0) && query.size(1) == grad_out.size(1) &&
query.size(2) == grad_out.size(2),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Q should have the same batch / num_head / seq_len as grad_out");
TORCH_INTERNAL_ASSERT(
query.size(3) == key.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Q/K should have the same head_dim");
TORCH_INTERNAL_ASSERT(
value.size(3) == grad_out.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: V should have the same head_dim as grad_out");
TORCH_INTERNAL_ASSERT(
query.size(1) == key.size(1),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: number of heads in K/V must equal to number of heads in Q");
TORCH_INTERNAL_ASSERT(
dropout_p == 0.0,
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0");
TORCH_INTERNAL_ASSERT(
logsumexp.dim() == 3 && logsumexp.size(0) == query.size(0) &&
logsumexp.size(1) == query.size(1) &&
logsumexp.size(2) == query.size(2) &&
"scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {B, H, T}");
std::optional<Tensor> attn_bias_opt;
if (attn_bias.defined()) {
attn_bias_opt = attn_bias;
}
const int64_t batch_size = query.size(0);
const int64_t num_head_q = query.size(1);
const int64_t num_head_kv = key.size(1);
const int64_t seq_len_q = query.size(2);
const int64_t seq_len_kv = key.size(2);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
auto grad_q = at::empty_like(query);
auto grad_k = at::empty_like(key);
auto grad_v = at::empty_like(value);
auto grad_attn_bias = attn_bias_opt.has_value()
? at::empty_like(attn_bias_opt.value())
: at::Tensor();
at::native::onednn::sdpa_backward(
batch_size,
num_head_q,
num_head_kv,
seq_len_q,
seq_len_kv,
head_dim_qk,
head_dim_v,
grad_out,
query,
key,
value,
out,
logsumexp,
attn_bias_opt,
is_causal,
scale.has_value() ? scale.value() : (1.0 / std::sqrt(query.size(3))),
grad_q,
grad_k,
grad_v);
return std::make_tuple(
std::move(grad_q),
std::move(grad_k),
std::move(grad_v),
std::move(grad_attn_bias));
}
REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);


@ -86,6 +86,28 @@ struct zeta_functor {
}
};
struct logaddexp_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
return c10::metal::logaddexp(a, b);
}
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
inline float operator()(const T a, const T b) {
return c10::metal::logaddexp(float(a), float(b));
}
};
struct logaddexp2_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
return c10::metal::logaddexp2(a, b);
}
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
inline float operator()(const T a, const T b) {
return c10::metal::logaddexp2(float(a), float(b));
}
};
struct xlog1py_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
@ -377,6 +399,10 @@ REGISTER_FLOAT_BINARY_OP(fmin);
REGISTER_FLOAT_BINARY_OP(nextafter);
REGISTER_FLOAT_BINARY_OP(zeta);
REGISTER_INT2FLOAT_BINARY_OP(zeta);
REGISTER_FLOAT_BINARY_OP(logaddexp);
REGISTER_INT2FLOAT_BINARY_OP(logaddexp);
REGISTER_FLOAT_BINARY_OP(logaddexp2);
REGISTER_INT2FLOAT_BINARY_OP(logaddexp2);
REGISTER_FLOAT_BINARY_OP(xlog1py);
REGISTER_INT2FLOAT_BINARY_OP(xlog1py);
REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t);
@ -463,6 +489,8 @@ REGISTER_BINARY_OP(add, float2, float2);
REGISTER_BINARY_OP(add, half2, half2);
REGISTER_BINARY_OP(sub, float2, float2);
REGISTER_BINARY_OP(sub, half2, half2);
REGISTER_BINARY_OP(logaddexp, float2, float2);
REGISTER_BINARY_OP(logaddexp, half2, half2);
REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2, float2);
REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2, half2);
REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2);


@ -89,6 +89,14 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "zeta");
}
static void logaddexp_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "logaddexp");
}
static void logaddexp2_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "logaddexp2");
}
static void xlog1py_mps_kernel(TensorIteratorBase& iter) {
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types");
lib.exec_binary_kernel(iter, "xlog1py");
@ -211,6 +219,8 @@ REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel)
REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel)
REGISTER_DISPATCH(logaddexp_stub, &logaddexp_mps_kernel);
REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_mps_kernel);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_mps_kernel)
REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_mps_kernel)
REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel)


@ -17,8 +17,6 @@
#include <ATen/ops/ge_native.h>
#include <ATen/ops/gt_native.h>
#include <ATen/ops/le_native.h>
#include <ATen/ops/logaddexp2_native.h>
#include <ATen/ops/logaddexp_native.h>
#include <ATen/ops/logical_and_native.h>
#include <ATen/ops/logical_or_native.h>
#include <ATen/ops/logical_xor_native.h>
@ -277,30 +275,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
}
}
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmWithTensor:sumTensor name:nil];
};
mps::binaryOpTensor(self, other, output, "logaddexp_out_mps", logaddexp_op_block);
}
TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
};
mps::binaryOpTensor(self, other, output, "logaddexp2_out_mps", logaddexp2_op_block);
}
TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
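The removed MPSGraph path computed log(exp(a) + exp(b)) literally, which overflows for large inputs; dedicated logaddexp kernels normally use the shifted form below (shown as a plain C++ sketch; whether the new c10::metal helper matches it exactly is not visible in this diff):

#include <algorithm>
#include <cmath>
#include <cstdio>

double logaddexp(double a, double b) {
  const double m = std::max(a, b);
  if (std::isinf(m) && m < 0) {
    return m;  // both inputs are -inf; avoid NaN from inf - inf
  }
  return m + std::log1p(std::exp(-std::fabs(a - b)));
}

int main() {
  std::printf("%f\n", logaddexp(1000.0, 1000.0));                      // ~1000.693147
  std::printf("%f\n", std::log(std::exp(1000.0) + std::exp(1000.0)));  // inf: the naive form overflows
}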


@ -1028,15 +1028,18 @@ TORCH_IMPL_FUNC(prod_out_mps)
}
TORCH_IMPL_FUNC(amax_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) {
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amax is not defined for complex types");
reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps");
}
TORCH_IMPL_FUNC(amin_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) {
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amin is not defined for complex types");
reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps");
}
TORCH_IMPL_FUNC(aminmax_out_mps)
(const Tensor& input_t, std::optional<int64_t> dim_opt, bool keepdim, const Tensor& min_t, const Tensor& max_t) {
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "aminmax is not defined for complex types");
reduction_out_mps(input_t,
dim_opt.has_value() ? OptionalIntArrayRef({*dim_opt}) : std::nullopt,
keepdim,


@ -31,6 +31,7 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v
indices.copy_(values.toType(at::ScalarType::Long));
return;
}
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "kthvalue is not implemented for complex types");
// issue #154890, raising error to prevent crash within MPSGraph until
// workaround is implemented.
TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");


@ -3622,8 +3622,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: logaddexp_out
MPS: logaddexp_out_mps
CPU, CUDA, MPS: logaddexp_out
tags: pointwise
- func: logaddexp(Tensor self, Tensor other) -> Tensor
@ -3635,8 +3634,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: logaddexp2_out
MPS: logaddexp2_out_mps
CPU, CUDA, MPS: logaddexp2_out
tags: pointwise
- func: logaddexp2(Tensor self, Tensor other) -> Tensor
@ -15097,7 +15095,7 @@
CPU: _scaled_dot_product_flash_attention_cpu
tags: nondeterministic_seeded
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None, bool compute_log_sumexp=True) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
XPU: _scaled_dot_product_fused_attention_overrideable_xpu
@ -15121,6 +15119,7 @@
variants: function
dispatch:
CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward
XPU: _scaled_dot_product_fused_attention_overrideable_backward_xpu
- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
dispatch:

View File

@ -467,6 +467,28 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
!options.has_layout() || options.layout() == kSparse,
"expected sparse layout, but got layout ",
options.layout());
if (indices.numel() > 0) {
Tensor min_indices =
std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
Tensor cpu_min_indices;
if (!indices.is_cpu()) {
cpu_min_indices = min_indices.to(at::DeviceType::CPU);
} else {
cpu_min_indices = min_indices;
}
auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
for (const auto d : c10::irange(indices.size(0))) {
int64_t min_index_in_dim = cpu_min_indices_accessor[d];
TORCH_CHECK(
min_index_in_dim >= 0,
"found negative index ",
min_index_in_dim,
" for dim ",
d);
}
}
return at::native::_sparse_coo_tensor_unsafe(
indices,
values,

View File

@ -768,8 +768,11 @@ Tensor scaled_dot_product_attention(
return std::get<0>(out_and_lse);
}
case SDPBackend::overrideable: {
bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
compute_logsumexp = compute_logsumexp ||
(at::GradMode::is_enabled() && attn_mask.has_value() && attn_mask.value().requires_grad());
auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable(
query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale);
query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale, compute_logsumexp);
return std::get<0>(out_lse_softmax);
}
case SDPBackend::math: {
@ -1015,7 +1018,8 @@ _scaled_dot_product_fused_attention_overrideable(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
std::optional<double> scale,
bool compute_logsumexp) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable not implemented. This is an operator for privateuse1 backends, please use TORCH_LIBRARY_IMPL to override this function ");
}
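Editorial note: the NOT_IMPLEMENTED message above tells privateuse1 backends to override this op via `TORCH_LIBRARY_IMPL`. A minimal, hypothetical sketch of such a registration follows; the kernel name and body are illustrative, and the signature (including the new trailing `compute_log_sumexp` flag) should be checked against the generated headers rather than taken as definitive.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

#include <optional>
#include <tuple>

// Hypothetical out-of-tree kernel; the 9-tuple mirrors the schema's returns
// (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed,
// philox_offset, debug_attn_mask).
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt,
           c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
my_sdpa_fused_attention_overrideable(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_bias,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale,
    bool compute_log_sumexp) {
  // Stub only: a real backend would compute fused attention here.
  TORCH_CHECK(false, "replace with a real fused-attention implementation");
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("_scaled_dot_product_fused_attention_overrideable",
         &my_sdpa_fused_attention_overrideable);
}
```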

View File

@ -1837,6 +1837,10 @@ class BenchmarkRunner:
def skip_models_for_cuda(self):
return set()
@property
def skip_models_for_xpu(self):
return set()
@property
def skip_models_for_cpu(self):
return set()
@ -3927,6 +3931,8 @@ def run(runner, args, original_dir=None):
runner.skip_models.update(runner.skip_models_for_cpu_aarch64)
elif args.devices == ["cuda"]:
runner.skip_models.update(runner.skip_models_for_cuda)
elif args.devices == ["xpu"]:
runner.skip_models.update(runner.skip_models_for_xpu)
if not args.multiprocess:
runner.skip_models.update(runner.skip_multiprocess_models)

View File

@ -124,6 +124,10 @@ class TorchBenchmarkRunner(BenchmarkRunner):
def skip_models_for_cuda(self):
return self._skip["device"]["cuda"]
@property
def skip_models_for_xpu(self):
return self._skip["device"]["xpu"]
@property
def skip_models_for_freezing_cuda(self):
return self._skip["freezing"]["cuda"]

View File

@ -217,6 +217,9 @@ skip:
cuda: []
xpu:
- *DETECTRON2_MODELS
test:
training:
- *DETECTRON2_MODELS

View File

@ -1,4 +1,4 @@
// Implementation of specal math functions for Metal
// Implementation of special math functions for Metal
#pragma once
#include <c10/metal/expm1f.h>
#include <c10/metal/igamma.h>
@ -624,6 +624,64 @@ inline T spherical_bessel_j0(T x) {
return static_cast<T>(::metal::sin(x) / x);
}
template <typename T>
inline ::metal::enable_if_t<is_scalar_floating_point_v<T>, T> logaddexp(
T a,
T b) {
float a0 = static_cast<float>(a);
float b0 = static_cast<float>(b);
if (::metal::isinf(a0) && a0 == b0) {
return static_cast<T>(a0);
} else {
float m0 = ::metal::max(a0, b0);
return static_cast<T>(
m0 + ::c10::metal::log1p(::metal::exp(-::metal::abs(a0 - b0))));
}
}
// The function is ported from mlx
template <typename T>
inline ::metal::enable_if_t<is_complex_v<T>, T> logaddexp(T a, T b) {
if (::metal::isnan(a.x) || ::metal::isnan(a.y) || ::metal::isnan(b.x) ||
::metal::isnan(b.y)) {
return T(NAN, NAN);
}
T maxval = a.x > b.x ? a : b;
T minval = a.x < b.x ? a : b;
constexpr auto inf = ::metal::numeric_limits<T>::infinity().x;
if (minval.x == -inf || maxval.x == inf) {
return maxval;
}
float2 maxval_ = static_cast<float2>(maxval);
float2 minval_ = static_cast<float2>(minval);
float m = ::metal::exp(minval_.x - maxval_.x);
float2 dexp{
m * ::metal::cos(minval_.y - maxval_.y),
m * ::metal::sin(minval_.y - maxval_.y),
};
return static_cast<T>(maxval_ + ::c10::metal::log1p(dexp));
}
template <typename T>
inline T logaddexp2(T a, T b) {
constexpr auto log_2 = float(0.693147180559945309417232121458176);
constexpr auto inv_log_2 = float(1) / log_2;
float a0 = static_cast<float>(a);
float b0 = static_cast<float>(b);
if (::metal::isinf(a0) && a0 == b0) {
return static_cast<T>(a0);
} else {
float m0 = ::metal::max(a0, b0);
return static_cast<T>(
m0 +
::c10::metal::log1p(::metal::pow(float(2), -::metal::abs(a0 - b0))) *
inv_log_2);
}
}
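Editorial note: both scalar kernels above use the standard overflow-safe rewrite of the naive formulas; restated in math (this is exactly what the `max`/`log1p` code computes):

$$\operatorname{logaddexp}(a,b)=\log\bigl(e^{a}+e^{b}\bigr)=\max(a,b)+\operatorname{log1p}\bigl(e^{-|a-b|}\bigr)$$

$$\operatorname{logaddexp2}(a,b)=\log_{2}\bigl(2^{a}+2^{b}\bigr)=\max(a,b)+\frac{\operatorname{log1p}\bigl(2^{-|a-b|}\bigr)}{\ln 2}$$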
template <typename T>
inline float xlog1py(T x, T y) {
if (::metal::isnan(y)) {

View File

@ -322,6 +322,24 @@ inline float log1p(float x) {
return rc;
}
// The function is ported from mlx
inline float2 log1p(float2 in) {
float x = in.x;
float y = in.y;
float zabs = ::metal::precise::sqrt(x * x + y * y);
float theta = ::metal::atan2(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1p(r), theta};
} else {
auto z0 = ::metal::sqrt((x + 1) * (x + 1) + y * y);
return {::metal::log(z0), theta};
}
}
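Editorial note: the complex `log1p` above follows the identity (for z = x + iy, with r = x(2 + x) + y²):

$$\log(1+z)=\tfrac{1}{2}\log\bigl((1+x)^{2}+y^{2}\bigr)+i\,\operatorname{atan2}(y,\,1+x),\qquad(1+x)^{2}+y^{2}=1+r$$

so for |z| < 0.5 the real part is evaluated as ½·log1p(r) to avoid cancellation near z = 0.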
template <typename T1, typename T2 = T1>
struct pair {
T1 first;

View File

@ -34,7 +34,7 @@ struct MemEvent {
bool overlaps(const MemBlock& a, const MemBlock& b) {
// two blocks dont overlap if
// |---a--------|--------------b--------|
// strat_a end_a <= start_b end_b
// start_a end_a <= start_b end_b
return !(
(a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
}

View File

@ -33,7 +33,7 @@ struct bitset final {
constexpr bitset() noexcept = default;
constexpr bitset(const bitset&) noexcept = default;
constexpr bitset(bitset&&) noexcept = default;
// there is an issure for gcc 5.3.0 when define default function as constexpr
// there is an issue for gcc 5.3.0 when define default function as constexpr
// see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
bitset& operator=(const bitset&) noexcept = default;
bitset& operator=(bitset&&) noexcept = default;

View File

@ -554,6 +554,17 @@ class DeviceCachingAllocator {
}
}
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_prop.global_mem_size);
}
void setMemoryFraction(double fraction) {
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
@ -724,6 +735,11 @@ class XPUAllocator : public DeviceAllocator {
device_allocators[device]->resetAccumulatedStats();
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();
}
void setMemoryFraction(double fraction, DeviceIndex device) {
assertValidDevice(device);
TORCH_CHECK_VALUE(
@ -777,6 +793,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
return allocator.recordStream(dataPtr, stream);
}
double getMemoryFraction(DeviceIndex device) {
return allocator.getMemoryFraction(device);
}
void setMemoryFraction(double fraction, DeviceIndex device) {
return allocator.setMemoryFraction(fraction, device);
}

View File

@ -25,6 +25,8 @@ C10_XPU_API void raw_delete(void* ptr);
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
C10_XPU_API double getMemoryFraction(DeviceIndex device);
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
} // namespace c10::xpu::XPUCachingAllocator
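Editorial note: a minimal usage sketch of the getter declared above. The header path is an assumption based on the declarations shown; the 1.0 default for the unset case comes from the allocator change earlier in this diff.

```cpp
#include <c10/xpu/XPUCachingAllocator.h>  // assumed path for the header above

// Returns the per-process memory fraction currently configured for `dev`;
// per the DeviceCachingAllocator change, this is 1.0 until a fraction is set.
double current_xpu_memory_fraction(c10::DeviceIndex dev) {
  return c10::xpu::XPUCachingAllocator::getMemoryFraction(dev);
}
```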

View File

@ -38,7 +38,7 @@ uint32_t crc32_combine (uint32_t crcA, uint32_t crcB, size_t lengthB);
/// compute CRC32 (bitwise algorithm)
uint32_t crc32_bitwise (const void* data, size_t length, uint32_t previousCrc32 = 0);
/// compute CRC32 (half-byte algoritm)
/// compute CRC32 (half-byte algorithm)
uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32 = 0);
#ifdef CRC32_USE_LOOKUP_TABLE_BYTE
@ -96,7 +96,7 @@ uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previo
#define __BIG_ENDIAN 4321
#endif
// define endianess and some integer data types
// define endianness and some integer data types
#if defined(_MSC_VER) || defined(__MINGW32__)
// Windows always little endian
#define __BYTE_ORDER __LITTLE_ENDIAN
@ -168,7 +168,7 @@ namespace
/// zlib's CRC32 polynomial
const uint32_t Polynomial = 0xEDB88320;
/// swap endianess
/// swap endianness
static inline uint32_t swap(uint32_t x)
{
#if defined(__GNUC__) || defined(__clang__)
@ -229,7 +229,7 @@ uint32_t crc32_bitwise(const void* data, size_t length, uint32_t previousCrc32)
}
/// compute CRC32 (half-byte algoritm)
/// compute CRC32 (half-byte algorithm)
uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32)
{
uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF
@ -662,7 +662,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB)
// - if you append length(B) zeros to A and call it A' (think of it as AAAA000)
// and prepend length(A) zeros to B and call it B' (think of it as 0000BBB)
// then exists a C' = A' ^ B'
// - remember: if you XOR someting with zero, it remains unchanged: X ^ 0 = X
// - remember: if you XOR something with zero, it remains unchanged: X ^ 0 = X
// - that means C' = A concat B so that crc(A concat B) = crc(C') = crc(A') ^ crc(B')
// - the trick is to compute crc(A') based on crc(A)
// and crc(B') based on crc(B)
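Editorial restatement of the combine trick described in the comment above (⊕ is byte-wise XOR, ∥ is concatenation, 0^n is n zero bytes): since A∥B = (A∥0^{|B|}) ⊕ (0^{|A|}∥B) and CRC behaves linearly under XOR (modulo the usual pre/post-inversion bookkeeping),

$$\operatorname{crc}(A\,\|\,B)=\operatorname{crc}\bigl(A\,\|\,0^{|B|}\bigr)\oplus\operatorname{crc}\bigl(0^{|A|}\,\|\,B\bigr)$$

which is exactly the C' = A' ⊕ B' decomposition the comment builds on.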

View File

@ -76,7 +76,7 @@ typedef struct mz_zip_archive mz_zip_archive;
// 2) Writing with 1-pass sequential access
// -> We must take care not to require updating values that have already
// been written. We place the variable-length index at the end and do
// not put any indicies into the header to fulfill this constraint.
// not put any index into the header to fulfill this constraint.
// The model.json, which contains all the metadata information,
// should be written as the last file. One reason is that the size of tensor

View File

@ -519,7 +519,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoadWithAllocator) {
std::tie(data_ptr, size) = reader.getRecord("key1", &overrideAllocator);
EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes);
// allcoate with base allocator
// allocate with base allocator
std::tie(data_ptr, size) = reader.getRecord("key1");
EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);

View File

@ -2,9 +2,9 @@
## Overview
The LibTorch Stable ABI (Application Binary Interface) provides an interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases.
The LibTorch Stable ABI (Application Binary Interface) provides a limited interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases. This limited set of APIs is not intended to replace existing LibTorch, but rather to provide a stable foundation for a majority of custom extension use cases. If there is any API you would like to see added to the stable ABI, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues).
The stable ABI consists of three main components:
The limited stable ABI consists of three main components:
1. **Stable C headers** - Low-level C API implemented by libtorch (primarily `torch/csrc/inductor/aoti_torch/c/shim.h`)
2. **Header-only C++ library** - Standalone utilities implemented in only headers such that there is no dependence on libtorch (`torch/headeronly/*`)
@ -14,8 +14,8 @@ We discuss each of these in detail
### `torch/headeronly`
This is a set of inlined C++ headers are completely decoupled from libtorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`.
The inlined C++ headers living in [`torch/headeronly`](https://github.com/pytorch/pytorch/tree/main/torch/headeronly) are completely decoupled from LibTorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`, as well as a libtorch-independent version of `TORCH_CHECK` that is `STD_TORCH_CHECK`. You can trust all APIs in the `torch::headeronly` namespace to not depend on `libtorch.so`. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt).
### `torch/csrc/stable`
@ -34,8 +34,14 @@ We are continuing to improve coverage in our `torch/csrc/stable` APIs. Please fi
### Stable C headers
The stable C headers used by AOTInductor form the foundation of the stable ABI. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs.
Further, the stack-based APIs discussed below which allow the user to call the PyTorch dispatcher don't provide strong guarantees on forward and backward compatibility.
The stable C headers started by AOTInductor form the foundation of the stable ABI. Presently, the available C headers include:
- [torch/csrc/inductor/aoti_torch/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/c/shim.h): Includes C-style shim APIs for commonly used functionality regarding Tensors, dtypes, CUDA, and the like.
- [torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h): Includes C-style shim APIs for ATen ops from `native_functions.yaml` (e.g. `aoti_torch_aten_new_empty`).
- [torch/csrc/inductor/aoti_torch/generated/c_shim_*.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated): Includes C-style shim APIs for specific backend kernels dispatched from `native_functions.yaml` (e.g. `aoti_torch_cuda_pad`). These APIs should only be used for the specific backend they are named after (e.g. `aoti_torch_cuda_pad` should only be used within CUDA kernels), as they opt out of the dispatcher.
- [torch/csrc/stable/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/stable/c/shim.h): We are building out more ABIs to logically live in `torch/csrc/stable/c` instead of continuing the AOTI naming that no longer makes sense for our general use case.
These headers are promised to be ABI stable across releases and adhere to a stronger backwards compatibility policy than LibTorch. Specifically, we promise not to modify them for at least 2 years after they are released. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs. Further, the stack-based APIs discussed below which allow the user to call into the PyTorch dispatcher do not provide strong guarantees on forward and backward compatibility of the underlying op that is called.
Unless absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable`
which will handle all the rough edges of the C API for the user.

View File

@ -76,6 +76,7 @@
:nosignatures:
empty_cache
get_per_process_memory_fraction
max_memory_allocated
max_memory_reserved
mem_get_info

View File

@ -1106,7 +1106,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
continue
self.copy_file(source_lib, target_lib)
# Delete old rpath and add @loader_lib to the rpath
# This should prevent delocate from attempting to package another instance
# This should prevent deallocate from attempting to package another instance
# of OpenMP library in torch wheel as well as loading two libomp.dylib into
# the address space, as libraries are cached by their unresolved names
install_name_tool_args = [

View File

@ -58,7 +58,8 @@ wrapper__scaled_dot_product_fused_attention_overrideable(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
std::optional<double> scale,
bool compute_log_sumexp) {
return at::native::openreg::_scaled_dot_product_fused_attention_overrideable(
query,
key,
@ -67,7 +68,8 @@ wrapper__scaled_dot_product_fused_attention_overrideable(
dropout_p,
is_causal,
return_debug_mask,
scale);
scale,
compute_log_sumexp);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>

View File

@ -47,7 +47,8 @@ _scaled_dot_product_fused_attention_overrideable(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
std::optional<double> scale,
bool compute_log_sumexp) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_v = value.size(3);

View File

@ -39,7 +39,8 @@ _scaled_dot_product_fused_attention_overrideable(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale);
std::optional<double> scale,
bool compute_log_sumexp);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor& grad_out,

View File

@ -827,7 +827,7 @@ class TestFullyShardShardPlacementFnMultiProcess(FSDPTest):
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
for iter_idx in range(5):
for _ in range(5):
ref_loss = ref_model(inp).sum()
loss = model(inp).sum()
self.assertEqual(ref_loss, loss)

View File

@ -31,17 +31,17 @@ if TEST_WITH_DEV_DBG_ASAN:
sys.exit(0)
_DISTRIBUTED_STATE_DICT_IMPLS = (
_DISTRIBUTED_STATE_DICT_IMPLS = {
StateDictType.LOCAL_STATE_DICT,
StateDictType.SHARDED_STATE_DICT,
)
}
class TestDistributedCheckpoint(FSDPTest):
@property
def world_size(self):
if torch.cuda.is_available():
gpu_cnt = torch.cuda.device_count()
if torch.accelerator.is_available():
gpu_cnt = torch.accelerator.device_count()
if gpu_cnt < 2:
return gpu_cnt
return 2
@ -93,7 +93,9 @@ class TestDistributedCheckpoint(FSDPTest):
# TODO: add resharding test case.
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestDistributedCheckpoint, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestDistributedCheckpoint, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -36,8 +36,8 @@ device_type = torch.device(get_devtype())
class TestApply(FSDPTest):
@property
def world_size(self):
if torch.cuda.is_available():
gpu_cnt = torch.cuda.device_count()
if torch.accelerator.is_available():
gpu_cnt = torch.accelerator.device_count()
if gpu_cnt < 2:
return gpu_cnt
return 2

View File

@ -2,7 +2,6 @@
# Owner(s): ["oncall: distributed"]
import sys
from pathlib import Path
import torch
import torch.distributed as dist
@ -45,53 +44,19 @@ class TestInstantiator(TestCase):
self.assertEqual(return_type_str, "Tuple[Tensor, int, str]")
def test_instantiate_scripted_remote_module_template(self):
dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
# Cleanup.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
for file_path in file_paths:
file_path.unlink()
# Check before run.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
num_files_before = len(list(file_paths))
self.assertEqual(num_files_before, 0)
generated_module = instantiator.instantiate_scriptable_remote_module_template(
MyModuleInterface
)
self.assertTrue(hasattr(generated_module, "_remote_forward"))
self.assertTrue(hasattr(generated_module, "_generated_methods"))
# Check after run.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
num_files_after = len(list(file_paths))
self.assertEqual(num_files_after, 1)
def test_instantiate_non_scripted_remote_module_template(self):
dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
# Cleanup.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
for file_path in file_paths:
file_path.unlink()
# Check before run.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
num_files_before = len(list(file_paths))
self.assertEqual(num_files_before, 0)
generated_module = (
instantiator.instantiate_non_scriptable_remote_module_template()
)
self.assertTrue(hasattr(generated_module, "_remote_forward"))
self.assertTrue(hasattr(generated_module, "_generated_methods"))
# Check after run.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
num_files_after = len(list(file_paths))
self.assertEqual(num_files_after, 1)
if __name__ == "__main__":
run_tests()

View File

@ -64,6 +64,10 @@ class TestDTensorDebugMode(TestCase):
self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall))
self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default)
# check stringification
self.assertTrue(hasattr(debug_mode.operators[0], "args_str"))
self.assertFalse(hasattr(debug_mode.operators[0], "args"))
def test_debug_string_inside_context(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@ -267,6 +271,7 @@ class TestDTensorDebugMode(TestCase):
record_torchfunction=True,
record_faketensor=True,
record_tensor_attributes=["a1", "a2"],
store_original_args=True,
) as debug_mode:
torch.matmul(y, x)
@ -279,6 +284,9 @@ class TestDTensorDebugMode(TestCase):
aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
)
self.assertTrue(hasattr(debug_mode.operators[0], "args"))
self.assertEqual(id(debug_mode.operators[0].args[0]), id(y))
@parametrize("has_inner_mode", [True, False])
@parametrize("has_outer_mode", [True, False])
def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):

View File

@ -20,18 +20,18 @@ from torch.distributed.tensor.experimental._attention import (
_cp_options,
_disable_context_parallel_dispatcher,
_enable_context_parallel_dispatcher,
_HeadTailLoadBalancer,
_is_causal_behavior,
_LoadBalancer,
_PerDocumentHeadTailLoadBalancer,
_PTRRLoadBalancer,
_RotateMethod,
context_parallel,
context_parallel_unshard,
set_rotate_method,
)
from torch.distributed.tensor.experimental._cp_custom_ops import flex_cp_allgather
from torch.distributed.tensor.experimental._load_balancer import (
_HeadTailLoadBalancer,
_LoadBalancer,
_PerDocumentHeadTailLoadBalancer,
_PTRRLoadBalancer,
from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import (
flex_cp_allgather,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
@ -52,7 +52,9 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
map_local_tensor_for_rank,
with_comms,
)
@ -800,11 +802,47 @@ class TestSharding(DTensorTestBase):
chunks = freqs_cis.chunk(self.world_size * 2)
self.assertEqual(
freqs_cis_shard,
torch.cat(
[chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
map_local_tensor_for_rank(
chunks,
self.rank,
lambda chunks, rank: torch.cat(
[chunks[rank], chunks[self.world_size * 2 - rank - 1]],
dim=0,
),
),
)
RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
RingAttentionTest,
skipped_tests=[
# Need to make attention implementation local tensor friendly, e.g.
# rewrite "rank local" logic
"test_ring_attention_sdpa",
],
)
CPFlexAttentionTestWithLocalTensor = create_local_tensor_test_class(
CPFlexAttentionTest,
skipped_tests=[
# Missing support for batched tensors
"test_cp_flex_attention_causal_mask",
"test_cp_flex_attention_document_mask",
],
)
TestCPCustomOpsWithLocalTensor = create_local_tensor_test_class(
TestCPCustomOps,
skipped_tests=[
# Missing support for fake tensors
"test_flex_cp_custom_op",
],
)
TestShardingWithLocalTensor = create_local_tensor_test_class(
TestSharding,
)
if __name__ == "__main__":
run_tests()

View File

@ -16,6 +16,7 @@ from torch.distributed.tensor import (
from torch.nn import functional as F
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
@ -203,34 +204,42 @@ class DistConvolutionOpsTest(DTensorTestBase):
self.assertTrue(b_dt.grad is not None)
self.assertTrue(x_dt.grad is None)
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
device_mesh = self.build_device_mesh()
model_copy = copy.deepcopy(model).to(device=self.device_type)
dist_model = distribute_module(model, device_mesh, _conv_fn)
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
out_dt = dist_model(arg_dt.to(device=self.device_type))
out = model_copy(arg)
return (out_dt.full_tensor(), out)
@with_comms
def test_conv1d(self):
device_mesh = self.build_device_mesh()
model = nn.Conv1d(64, 64, 3, padding=1)
model_gt = copy.deepcopy(model)
x = torch.randn(1, 64, 8)
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
model_dt = distribute_module(
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
)
out_dt = model_dt(x_dt)
out = model_gt(x)
x = torch.randn(1, 64, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt.shape, out.shape)
@with_comms
def test_conv3d(self):
device_mesh = self.build_device_mesh()
model = nn.Conv3d(64, 64, 3, padding=1)
model_gt = copy.deepcopy(model).to(device=self.device_type)
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
model_dt = distribute_module(
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
)
out_dt = model_dt(x_dt)
out = model_gt(x)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt.shape, out.shape)
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
DistConvolutionOpsTest,
# Send / recv ops are not supported
skipped_tests=[
"test_conv1d",
"test_conv3d",
"test_conv_backward_none_grad_inp",
"test_depthwise_convolution",
"test_downsampling_convolution",
],
)
if __name__ == "__main__":
run_tests()

View File

@ -520,6 +520,21 @@ class DTensorExportTest(TestCase):
2,
)
def test_union_typed_annotation(self):
def fn(leaf: torch.Tensor | DTensor):
def nest_fn(leaf: torch.Tensor | DTensor):
# def nest_fn(leaf: Union[torch.Tensor, DTensor]): # this works
if isinstance(leaf, DTensor):
leaf = leaf.to_local()
return leaf
return nest_fn(leaf) + 1
z = torch.randn(16, 16)
gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,))
self.assertEqual(fn(z), gm(z)[0])
instantiate_parametrized_tests(DTensorExportTest)

View File

@ -887,6 +887,135 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(test_out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_custom_estimation_with_fake_tensor_mode(self):
"""Test that custom estimation can use FakeTensorMode for analysis."""
from torch._subclasses.fake_tensor import FakeTensorMode
estimation_calls = 0
def estimate_with_fake_mode(fx_node, compute_multiplier=1.0):
with FakeTensorMode():
nonlocal estimation_calls
estimation_calls += 1
assert isinstance(torch.rand([20]), torch._subclasses.FakeTensor)
return 1.0
patches = get_bucket_patches()
patches["aten_distributed_optimizations.custom_runtime_estimation"] = (
estimate_with_fake_mode
)
def func(a, b, *, ranks):
# Two independent all_gathers that should be bucketed
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
# Matmul that can hide the collectives
mm1 = torch.matmul(a, a)
return ag1.sum() + ag2.sum() + mm1.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type)
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
with torch._inductor.config.patch(patches):
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(
compiled, inputs_a, inputs_b
)
# Verify the custom estimation was called
self.assertTrue(
estimation_calls > 0, "Custom estimation should have been called"
)
correct = func(inputs_a, inputs_b, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_multidtype_bucketing(self):
"""Test that all_gathers with different dtypes get bucketed together."""
def func(a, b, c, *, ranks):
# Three all_gathers with different dtypes
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) # float32
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) # float16
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) # float16
# Use all results
return ag1.sum() + ag2.sum() + ag3.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(4, 4, dtype=torch.float32, device=device_type)
b = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 2
c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
# Should have 1 bucketed all_gather despite different dtypes
FileCheck().check_count(
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
).run(aten_graph_str)
# Verify correctness
correct = func(a, b, c, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_basic_all_reduce_bucketing(self):
"""Test that independent all_reduce operations get bucketed together."""
def func(a, b, c):
# Three independent all_reduces that should be bucketed
ar1 = _functional_collectives.all_reduce(a, "sum", "0")
ar2 = _functional_collectives.all_reduce(b, "sum", "0")
ar3 = _functional_collectives.all_reduce(c, "sum", "0")
return ar1.sum() + ar2.sum() + ar3.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3
compiled = torch.compile(func)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
# Should see a single bucketed all_reduce
FileCheck().check_count(
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
).run(aten_graph_str)
# Verify correctness
correct = func(a, b, c)
self.assertTrue(same(out, correct))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -0,0 +1,572 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
import torch.distributed as dist
import torch.fx as fx
# for some reason importing functional collectives after dynamo breaks collectives handling!
from torch._C import FileCheck
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._ordered_set import OrderedSet
# flake8: noqa: B950
# Owner(s): ["module: inductor"]
aten = torch.ops.aten
from torch.testing._internal.common_fsdp import get_devtype
device_type = str(get_devtype())
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
@requires_accelerator_dist_backend(["nccl", "xccl"])
def build_collective_info(graph, hiding_annotations):
"""
Build CollectiveInfo dict from manual hiding annotations.
hiding_annotations: dict mapping collective_start -> hiding_compute_node
"""
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
collective_info = {}
# Find all collective starts and their corresponding waits
start_to_wait = {}
for node in graph.nodes:
if node.op == "call_function" and "wait_tensor" in str(node.target):
wait_input = node.args[0]
if isinstance(wait_input, fx.Node):
start_to_wait[wait_input] = node
# Build CollectiveInfo for each collective
for start_node, wait_node in start_to_wait.items():
hiding_node = hiding_annotations.get(start_node)
# Estimate size and time
size_bytes = 16 * 4 # 4x4 tensor of floats
estimated_time_ms = 1.0 # Dummy time
exposed_time_ms = 0.0 if hiding_node else 1.0 # Hidden if has hiding_node
collective_info[start_node] = CollectiveInfo(
start_node=start_node,
wait_node=wait_node,
size_bytes=size_bytes,
estimated_time_ms=estimated_time_ms,
exposed_time_ms=exposed_time_ms,
hiding_node=hiding_node,
)
return collective_info
def compute_ancestors(graph):
"""Compute ancestor sets for all nodes in the graph."""
node_ancestors = {}
for node in graph.nodes:
ancestors = OrderedSet()
stack = list(node.all_input_nodes)
visited = set()
while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)
ancestors.add(current)
stack.extend(current.all_input_nodes)
node_ancestors[node] = ancestors
return node_ancestors
@requires_accelerator_dist_backend()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@instantiate_parametrized_tests
class TestOverlapPreservingBucketing(InductorTestCase):
"""
Unit tests for overlap-preserving bucketing pass.
"""
@classmethod
def setUpClass(cls):
super().setUpClass()
from torch.testing._internal.distributed.fake_pg import FakeStore
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
cls.device = "cuda"
@classmethod
def tearDownClass(cls):
super().tearDownClass()
dist.destroy_process_group()
def test_can_bucket_independent_collectives(self):
"""
Test that independent collectives with separate hiding nodes CAN bucket.
Graph structure:
ag1_start -> ag2_start -> mm1 (hides ag1) -> mm2 (hides ag2) -> ag1_wait -> ag2_wait
"""
def func(a, b):
group_name = "0"
group_size = 1
# Start both collectives
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b, group_size, group_name
)
# Independent compute that can hide both
mm1 = torch.mm(a, a)
mm2 = torch.mm(b, b)
# Wait for both
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(4, 4, device=self.device) * 2
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes using find_nodes
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
mm1, mm2 = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# Manually annotate hiding relationships
hiding_annotations = {
ag1: mm1, # mm1 hides ag1
ag2: mm2, # mm2 hides ag2
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
# Verify: should have 1 bucketed collective (all_gather_into_tensor_out)
graph_str = str(traced.graph)
FileCheck().check_count("all_gather_into_tensor_out", 1, exactly=False).run(
graph_str
)
def test_cant_bucket_nested_hiding_intervals(self):
"""
Test that nested hiding intervals prevent bucketing.
Graph structure:
ag1_start -> ag2_start -> mm2 (hides ag2) -> ag2_wait -> mm1 (hides ag1) -> ag1_wait
ag2's hiding interval is nested inside ag1's hiding interval.
"""
def func(a, b):
group_name = "0"
group_size = 1
# ag1 starts first
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
# ag2 starts (inside ag1's interval)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b, group_size, group_name
)
# mm2 hides ag2
mm2 = torch.mm(b[:2, :2], b[:2, :2])
# ag2 waits (still inside ag1's interval)
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
# mm1 uses ag2's result and hides ag1
mm1 = torch.mm(a + ag2_out[:4, :4], a)
# ag1 waits last
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(4, 4, device=self.device) * 2
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes using find_nodes
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
mm_nodes = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# mm2 is the first mm, mm1 is the second (based on graph order)
mm2 = mm_nodes[0]
mm1 = mm_nodes[1]
# Manually annotate hiding relationships
hiding_annotations = {
ag1: mm1, # mm1 hides ag1
ag2: mm2, # mm2 hides ag2
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
# Verify: nested hiding intervals should prevent bucketing
# Should have 2 separate all_gathers, not 1 bucketed one
graph_str = str(traced.graph)
FileCheck().check_count("all_gather_into_tensor", 2, exactly=False).run(
graph_str
)
@parametrize("final_mm_hidden", (True, False))
def test_cant_bucket_ag_with_rs_hiding_interval_between(self, final_mm_hidden):
"""
Test that all_gathers can't bucket when a reduce_scatter's hiding interval is between them.
Graph structure:
ag1_start -> mm1 (hides ag1) -> ag1_wait ->
rs_start -> mm2 (hides rs) -> rs_wait ->
if final_mm_hidden:
ag2_start -> mm3 (hides ag2) -> ag2_wait
if final_mm_hidden:
Bucketing ag1 and ag2 would require moving one of them, which would break hiding relationships:
- Moving ag2 earlier would break ag2's hiding by mm3
- Moving ag1 later would break ag1's hiding by mm1
- The rs hiding interval creates an obstacle between them
otherwise, we can bucket
"""
def func(a, b, c):
group_name = dist.distributed_c10d._get_default_group().group_name
group_size = 1
# First all_gather
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
mm1 = torch.mm(a, a) # hides ag1
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
# Reduce scatter in between
rs = torch.ops._c10d_functional.reduce_scatter_tensor(
b, "sum", group_size, group_name
)
mm2 = torch.mm(b[:4, :4], b[:4, :4]) # hides rs
rs_out = torch.ops._c10d_functional.wait_tensor(rs)
# Second all_gather
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
c, group_size, group_name
)
mm3 = torch.mm(c, c) # hides ag2
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
return ag1_out.sum() + rs_out.sum() + ag2_out.sum(), mm1, mm2, mm3
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(8, 4, device=self.device)
c = torch.ones(4, 4, device=self.device)
# Trace with make_fx
traced = make_fx(func)(a, b, c)
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
(rs,) = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.reduce_scatter_tensor.default,
)
mm1, mm2, mm3 = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# Manually annotate hiding relationships
hiding_annotations = {
ag1: mm1, # mm1 hides ag1
# rs: mm2, # mm2 hides rs
ag2: mm3,
}
if final_mm_hidden:
hiding_annotations[rs] = mm2
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing logic to find buckets (without applying them, which would require process groups)
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
graph_str = str(traced.graph)
# check order of mms preserved
FileCheck().check("%mm").check("%mm_1").check("%mm_2").run(graph_str)
if final_mm_hidden:
# Should NOT bucket - 2 separate all_gathers
# Count all_gather node names (works even when wrapped in control_deps)
FileCheck().check_count("%all_gather_into_tensor", 2, exactly=False).run(
graph_str
)
else:
# Should bucket - 1 bucketed all_gather (all_gather_into_tensor_out)
FileCheck().check_count(
"%all_gather_into_tensor_out", 1, exactly=False
).run(graph_str)
def test_can_bucket_all_reduce(self):
"""
Test that all_reduce operations CAN bucket together.
Graph structure:
ar1_start -> ar2_start -> mm1 (hides ar1) -> mm2 (hides ar2) -> ar1_wait -> ar2_wait
"""
def func(a, b):
group_name = "0"
# Start both all_reduce operations
ar1 = torch.ops._c10d_functional.all_reduce(a, "sum", group_name)
ar2 = torch.ops._c10d_functional.all_reduce(b, "sum", group_name)
# Independent compute that can hide both
mm1 = torch.mm(a, a)
mm2 = torch.mm(b, b)
# Wait for both
ar1_out = torch.ops._c10d_functional.wait_tensor(ar1)
ar2_out = torch.ops._c10d_functional.wait_tensor(ar2)
return ar1_out.sum() + ar2_out.sum() + mm1.sum() + mm2.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(4, 4, device=self.device) * 2
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes
ar1, ar2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_reduce.default,
)
mm1, mm2 = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# For all_reduce, start_node == wait_node (no separate wait)
hiding_annotations = {
ar1: mm1,
ar2: mm2,
}
# Build collective info
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
# Verify: should have 1 bucketed all_reduce
# After bucketing, there should be only one all_reduce node (the bucketed one)
graph_str = str(traced.graph)
FileCheck().check_count("%all_reduce", 1, exactly=True).check_count(
"%mm", 2
).run(graph_str)
def test_can_bucket_multidtype_collectives(self):
"""
Test that all_gathers with different dtypes CAN bucket together.
Graph structure:
ag1_float32 -> mm1 (hides ag1) -> ag1_wait
ag2_bfloat16 -> mm2 (hides ag2) -> ag2_wait
"""
def func(a, b):
group_name = "0"
group_size = 1
# Start both collectives with different dtypes
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a,
group_size,
group_name, # float32
)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b,
group_size,
group_name, # bfloat16
)
# Independent compute that can hide both
mm1 = torch.mm(a, a)
mm2 = torch.mm(b.float(), b.float())
# Wait for both
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device, dtype=torch.float32)
b = torch.ones(4, 4, device=self.device, dtype=torch.bfloat16)
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes using find_nodes
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
mm_nodes = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
mm1 = mm_nodes[0]
mm2 = mm_nodes[1]
# Manually annotate hiding relationships
hiding_annotations = {
ag1: mm1, # mm1 hides ag1
ag2: mm2, # mm2 hides ag2
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing with multidtype mode
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
bucket_mode="custom_ops_multidtype",
)
bucketer.bucket_collectives()
# Verify: should have 1 bucketed collective (all_gather_into_tensor_out)
# even though dtypes are different
graph_str = str(traced.graph)
FileCheck().check_count("all_gather_into_tensor_out", 1, exactly=False).run(
graph_str
)
if __name__ == "__main__":
run_tests()

View File

@ -41,6 +41,20 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
def aot_eager_regional_inductor():
"""
Regional inductor backend for AOT autograd.
Uses regional_inductor as both forward and backward compiler.
"""
from torch._dynamo.backends.common import aot_autograd
from torch.fx.passes.regional_inductor import regional_inductor
return aot_autograd(
fw_compiler=regional_inductor,
bw_compiler=regional_inductor,
)
def saved_tensors_hooks_to_gm(
pack_fn,
unpack_fn,
@ -1898,6 +1912,171 @@ class AOTAutogradCacheTests(InductorTestCase):
# no recompiles
self.assertFalse(counters)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"bundled_autograd_cache": True})
def test_regional_inductor_basic(self):
"""
Basic test for regional inductor with bundled autograd cache.
Tests that regional inductor compilation results can be cached and hit.
"""
import torch.fx.traceback as fx_traceback
def fn(x, y):
sin = torch.sin(x)
# Mark this region to be compiled with inductor
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
x = torch.randn(10, device="cpu")
y = torch.randn(10, device="cpu")
# Compile with regional inductor backend
compiled_fn = torch.compile(
fn, backend=aot_eager_regional_inductor(), fullgraph=True
)
# First call should miss in cache
result1 = compiled_fn(x, y)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
# Second call should hit (after clearing dynamo)
self._clear_dynamo_and_codecache()
result2 = compiled_fn(x, y)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
# Results should be the same
self.assertEqual(result1, result2)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"bundled_autograd_cache": True})
def test_regional_inductor_with_backward(self):
"""
Test regional inductor with backward pass and bundled autograd cache.
Note: Regional inductor triggers multiple AOT autograd compilations:
- One for the outer graph (with regional inductor backend)
- One for each marked region (via standalone_compile)
"""
import torch.fx.traceback as fx_traceback
def fn(x, y):
sin = torch.sin(x)
# Mark this region to be compiled with inductor
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
x2 = x.detach().clone().requires_grad_(True)
y2 = y.detach().clone().requires_grad_(True)
# Compile with regional inductor backend
compiled_fn = torch.compile(
fn, backend=aot_eager_regional_inductor(), fullgraph=True
)
# First call: AOT autograd compiles the outer graph (1 miss)
# Regional inductor then compiles the marked region (1 more miss)
result1 = compiled_fn(x, y)
result1.sum().backward()
# We expect 2 cache misses: outer graph + marked region
initial_misses = counters["aot_autograd"]["autograd_cache_miss"]
initial_saves = counters["aot_autograd"]["autograd_cache_saved"]
self.assertGreater(initial_misses, 0)
self.assertGreater(initial_saves, 0)
# Second call should hit (after clearing dynamo)
self._clear_dynamo_and_codecache()
result2 = compiled_fn(x2, y2)
result2.sum().backward()
# Should have cache hits now
final_hits = counters["aot_autograd"]["autograd_cache_hit"]
self.assertGreater(final_hits, 0)
# Cache misses and saves should not increase
self.assertEqual(
counters["aot_autograd"]["autograd_cache_miss"], initial_misses
)
self.assertEqual(
counters["aot_autograd"]["autograd_cache_saved"], initial_saves
)
# Results and gradients should be the same
self.assertEqual(result1, result2)
self.assertEqual(x.grad, x2.grad)
self.assertEqual(y.grad, y2.grad)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"bundled_autograd_cache": True})
def test_regional_inductor_cache_miss_on_change(self):
"""
Test that changing the function causes a cache miss with regional inductor.
Regional inductor creates multiple AOT compilations, so we track
the change in cache misses rather than absolute counts.
"""
import torch.fx.traceback as fx_traceback
def fn1(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
def fn2(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 2 # Changed from +1 to +2
return torch.sin(add)
x = torch.randn(10)
y = torch.randn(10)
# Compile first function
compiled_fn1 = torch.compile(
fn1, backend=aot_eager_regional_inductor(), fullgraph=True
)
result1 = compiled_fn1(x, y)
first_misses = counters["aot_autograd"]["autograd_cache_miss"]
first_saves = counters["aot_autograd"]["autograd_cache_saved"]
self.assertGreater(first_misses, 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertGreater(first_saves, 0)
# Compile second function (different graph)
self._clear_dynamo_and_codecache()
compiled_fn2 = torch.compile(
fn2, backend=aot_eager_regional_inductor(), fullgraph=True
)
result2 = compiled_fn2(x, y)
# Should miss because graph is different (more misses than before)
self.assertGreater(
counters["aot_autograd"]["autograd_cache_miss"], first_misses
)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertGreater(
counters["aot_autograd"]["autograd_cache_saved"], first_saves
)
# Results should be different
self.assertNotEqual(result1, result2)
@functorch_config.patch({"bundled_autograd_cache": True})
class AOTAutogradCacheBundledTests(AOTAutogradCacheTests):

View File

@ -2064,6 +2064,23 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
self.assertEqual(f(), 1)
def test_error_on_graph_break_nonempty_checkpoint(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x):
x = x + 1
x = x + 1
x = x + 1
with torch._dynamo.error_on_graph_break(True):
torch._dynamo.graph_break()
return x + 1
with self.assertRaises(Unsupported):
fn(torch.ones(3))
self.assertEqual(cnts.frame_count, 0)
def test_nested_compile_fullgraph(self):
# Test that fullgraph=True cannot be toggled back by fullgraph=False
inp = torch.ones(3)

View File

@ -341,7 +341,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
def fn(x, d):
y = 0
for idx, (key, value) in enumerate(d.items()):
for idx, value in enumerate(d.values()):
if idx == 0:
y += torch.sin(x * value)
else:
@ -366,7 +366,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
def fn(x, d):
y = 0
for idx, (key, value) in enumerate(d.items()):
for idx, value in enumerate(d.values()):
if idx == 0:
y += torch.sin(x * value)
else:
@ -847,7 +847,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
d = {"a": 2, "b": 3, "c": 5 * x}
mp = types.MappingProxyType(d)
y = torch.sin(x * mp["a"])
for k, v in mp.items(): # noqa: PERF102
for v in mp.values():
y += torch.cos(x * v)
return mp
@ -864,7 +864,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
def fn(x):
mp = types.MappingProxyType(d)
y = torch.sin(x * mp["a"])
for k, v in mp.items(): # noqa: PERF102
for v in mp.values():
y += torch.cos(x * v)
d["d"] = 4
return mp
@ -885,7 +885,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
def fn(x, mp):
y = torch.sin(x * mp["a"])
for k, v in mp.items(): # noqa: PERF102
for v in mp.values():
y += torch.cos(x * v)
if isinstance(mp, types.MappingProxyType):
y *= 2

View File

@ -1159,6 +1159,7 @@ User code traceback:
torch._dynamo.graph_break()
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
Most recent bytecode instructions traced (max 20):
TRACE RESUME 0 []
TRACE LOAD_FAST 'x' []
@ -1172,7 +1173,8 @@ TRACE STORE_FAST 'z' [TensorVariable()]
TRACE LOAD_GLOBAL 'torch' []
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker(unrealized: <class 'module'>)]
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker(unrealized: <class 'module'>)]
TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: <class 'function'>)]""",
TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: <class 'function'>)]
""",
)
@torch._dynamo.config.patch(verbose=True)
@ -1234,17 +1236,28 @@ TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: <class 'function'>)]
self.assertIn("Foo().attr = x # 1", records[-1].getMessage())
def post_munge(s):
return re.sub(
s = re.sub(
r"torch_dynamo_resume_in_f(\d)_at_(\d+)",
r"torch_dynamo_resume_in_f\1_at_N",
s,
)
# remove most recent bytecode instructions
# DOTALL is needed to entirely remove TRACE ... lines (including the newline)
return re.sub(r"TRACE.*$", "", s, flags=re.DOTALL)
self.assertExpectedInline(
post_munge(munge_exc(records[-1].getMessage(), skip=0)),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: STORE_ATTR-caused graph break
Graph Break Reason: Encountered graph break when attempting to store an object's attribute (STORE_ATTR):
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "test_error_messages.py", line N, in test_graph_break_traceback_above_dynamo_shows_user_code
f3(torch.randn(3))
@ -1257,8 +1270,12 @@ User code traceback:
File "test_error_messages.py", line N, in torch_dynamo_resume_in_f3_at_N
Foo().attr = x
File "test_error_messages.py", line N, in __setattr__
torch._dynamo.graph_break()
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
Most recent bytecode instructions traced (max 20):
""",
)
@ -1483,6 +1500,110 @@ from user code:
):
fn(torch.randn(3))
@make_logging_test(graph_breaks=True)
def test_step_graph_break(self, records):
@torch.compile(backend="eager")
def fn(x):
x = x + 1
x = x + 2
torch._dynamo.step_unsupported()
return x + 4
fn(torch.ones(3))
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
User code traceback:
File "test_error_messages.py", line N, in test_step_graph_break
fn(torch.ones(3))
File "test_error_messages.py", line N, in fn
torch._dynamo.step_unsupported()
""",
)
torch._dynamo.reset()
with torch._dynamo.error_on_graph_break(True):
self.assertExpectedInlineMunged(
Unsupported,
lambda: fn(torch.ones(3)),
"""\
cannot resume from torch._dynamo.step_unsupported()
Explanation: traced torch._dynamo.step_unsupported(), but Dynamo is instructed to error on graph break. This graph break is used for debugging only.
Hint: Remove the torch._dynamo.step_unsupported() call.
Hint: Make sure fullgraph=False and error_on_graph_break=False.
Hint: This is likely to be a Dynamo bug. Please report an issue to PyTorch.
Developer debug context:
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0283.html
from user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.step_unsupported()""",
)
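The escalation behavior exercised above can be reproduced in a few lines outside the test harness; a minimal sketch using only APIs that already appear in these tests:
import torch
from torch._dynamo.exc import Unsupported

@torch.compile(backend="eager")
def f(x):
    torch._dynamo.graph_break()  # user-inserted graph break
    return x + 1

f(torch.ones(3))  # succeeds: by default Dynamo splits the graph and falls back

torch._dynamo.reset()
with torch._dynamo.error_on_graph_break(True):
    try:
        f(torch.ones(3))
    except Unsupported:
        print("graph break escalated to a hard error")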
@make_logging_test(graph_breaks=True)
def test_store_attr_graph_break(self, records):
class Foo:
def __setattr__(self, name, value):
torch._dynamo.graph_break()
@torch.compile(backend="eager")
def fn(x):
Foo().attr = x
fn(torch.ones(3))
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break when attempting to store an object's attribute (STORE_ATTR):
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "test_error_messages.py", line N, in test_store_attr_graph_break
fn(torch.ones(3))
File "test_error_messages.py", line N, in fn
Foo().attr = x
File "test_error_messages.py", line N, in __setattr__
torch._dynamo.graph_break()
""",
)
torch._dynamo.reset()
with torch._dynamo.error_on_graph_break(True):
self.assertExpectedInlineMunged(
Unsupported,
lambda: fn(torch.ones(3)),
"""\
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "test_error_messages.py", line N, in fn
Foo().attr = x
File "test_error_messages.py", line N, in __setattr__
torch._dynamo.graph_break()""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -363,6 +363,40 @@ class FxGraphRunnableTest(TestCase):
self._exec_and_verify_payload()
def test_metrics_context(self):
"""
When TORCH_COMPILE_DEBUG is set, provenance_tracking_level is set to 1, and the
generated fx_graph_runnable used to crash with:
RuntimeError: Cannot add inductor_provenance outside of a MetricsContext
"""
import torch._inductor.config as inductor_config
def f(x):
return x * 2 + 1
# Enable provenance tracking to trigger the code path that adds metrics
with inductor_config.patch(
{"trace.enabled": True, "trace.provenance_tracking_level": 1}
):
x = torch.randn(4, 4)
torch.compile(f)(x)
self._exec_and_verify_payload()
@torch._dynamo.config.patch(assume_static_by_default=False)
def test_dynamic_expression(self):
"""
Test that we do not emit something like "s27*s53**2 = 36".
"""
def f(x):
return torch.ops.aten._adaptive_avg_pool2d(
x, (6, 6)
), torch.ops.aten._adaptive_avg_pool2d(x + 1, (2, 5))
x = torch.randn(2, 4, 16, 16)
torch.compile(f)(x)
self._exec_and_verify_payload()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -2858,7 +2858,7 @@ class GraphModule(torch.nn.Module):
def fn(x):
return wrap(lambda x: model(x), x)
for i in range(2):
for _ in range(2):
# second iteration is key, hooks would have fired during aot trace
# on first iter
activations.clear()

View File

@ -807,7 +807,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
def __init__(self) -> None:
super().__init__()
self.layers = torch.nn.ModuleList()
for i in range(10):
for _ in range(10):
layer = torch.nn.Linear(16, 16)
layer.register_forward_pre_hook(lambda _, inp: fw_hook(inp))
layer = torch.compile(layer, backend=cnts)

View File

@ -4347,6 +4347,33 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
res = opt_fn(x)
self.assertTrue(same(ref, res))
def test_typing_union_new_syntax(self):
def fn(x):
def inner1(y: torch.Tensor | None):
return y
def inner2(y: None | torch.Tensor):
return y
def inner3(y: torch.Tensor | list[int]):
return y
return x + 1
torch.compile(fn, backend="eager", fullgraph=True)(torch.ones(3))
@unittest.expectedFailure
def test_typing_union_new_syntax_reconstruct(self):
def fn(x):
return (
x + 1,
torch.Tensor | None,
None | torch.Tensor,
torch.Tensor | list[int],
)
torch.compile(fn, backend="eager", fullgraph=True)(torch.ones(3))
def test_optimize_on_module(self):
class MockModule(torch.nn.Module):
def __init__(self) -> None:

View File

@ -262,7 +262,7 @@ class ConstLoop(torch.nn.Module):
self.count = 3
def forward(self, x):
for i in range(self.count):
for _ in range(self.count):
x = torch.sigmoid(self.linear1(x))
return x
@ -509,7 +509,7 @@ class CfgModule(torch.nn.Module):
self.layer = torch.nn.Linear(10, 10)
def forward(self, x):
for i in range(self.cfg.count):
for _ in range(self.cfg.count):
x = self.layer(x + self.cfg.val)
return x
@ -781,7 +781,7 @@ class ParametersModule5(torch.nn.Module):
def forward(self, x):
counter = 0
for param in self.parameters():
for _param in self.parameters():
counter += 1
return x * self.scale * counter
@ -841,7 +841,7 @@ class EnumValues(torch.nn.ModuleDict):
def forward(self, init_features):
features = [init_features]
for idx, layer in enumerate(self.values()):
for layer in self.values():
new_features = layer(features)
features.append(new_features)
return torch.cat(features, 1)
@ -2161,7 +2161,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
opt_mod = torch.compile(fn, backend=cnts)
for i in range(8):
for _ in range(8):
mod = Mod()
opt_mod(torch.randn(5, 5), mod)
@ -2516,7 +2516,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
compiled_model = torch.compile(model, backend="aot_eager")
activations = compiled_activations
for i in range(2):
for _ in range(2):
# second iteration is key, hooks would have fired during aot trace
# on first iter
compiled_activations.clear()
@ -2526,7 +2526,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
loss.backward()
activations = eager_activations
for i in range(2):
for _ in range(2):
# second iteration is key, hooks would have fired during aot trace
# on first iter
eager_activations.clear()
@ -2575,12 +2575,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
def save_activations(mod, inp, out):
activations.append(inp)
for name, module in model.named_modules():
for module in model.modules():
module.register_forward_hook(save_activations)
cnt = torch._dynamo.testing.CompileCounter()
model = torch.compile(model, backend=cnt, fullgraph=True)
for i in range(2):
for _ in range(2):
# second iteration is key, hooks would have fired during aot trace
# on first iter
activations.clear()
@ -2703,7 +2703,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
model = torch.compile(model, backend="aot_eager")
for i in range(2):
for _ in range(2):
# second iteration is key, hooks would have fired during aot trace
# on first iter
x = torch.randn((20, 10))

View File

@ -380,6 +380,41 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreak
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 13)
def test_dead_nested_cells(self):
global f1, f2, f3
def f3(x, cell1):
cell1 += 2
x = x + cell1
torch._dynamo.graph_break()
return x + cell1
def f1(cell1=0):
def inner(x):
x += 4
x = f3(x, cell1)
return x + 8
return inner
def f2(x):
return f1()(x + 16) + 32
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f2)
x = torch.zeros(3)
res = f2(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
# If we don't handle dead cells in nested functions correctly,
# frame_count will increase since we also
# graph break when we attempt to codegen inner.
# The exact issue was that side_effects was failing to codegen inner's cell's creation.
# So when we try to codegen cells for resume functions, we end up trying to codegen
# a CellVariable without a source, which leads to a graph break we can't resume from.
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 6)
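For readers unfamiliar with the "cells" the comment above refers to: plain CPython (no Dynamo involved) stores variables captured by a nested function in cell objects, and these are what side_effects has to recreate when it codegens a resume function. A short illustration:
def make_inner(cell1=0):
    def inner(x):
        return x + cell1  # cell1 is captured from the enclosing frame
    return inner

inner = make_inner(41)
print(inner.__code__.co_freevars)          # ('cell1',)
print(inner.__closure__[0].cell_contents)  # 41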
def test_cells_double_graph_break(self):
def f1(x1):
cell1 = x1 + 1
@ -806,6 +841,39 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreak
)
)
def test_disable_nested_graph_breaks(self):
global f1, f2, f3, f4, f5
def f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def f2(x):
return f1(x + 4) + 8
# NOTE since the disable_nested_graph_breaks decorator is implemented as a
# context manager, we don't need to separately test context manager usage.
@torch._dynamo.disable_nested_graph_breaks
def f3(x):
return f2(x + 16) + 32
def f4(x):
return f3(x + 64) + 128
def f5(x):
return f4(x + 256) + 512
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f5)
x = torch.zeros(3)
res = f5(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
# 2 frames from each of f5+f4, f3, f2, f1
self.assertEqual(cnts.frame_count, 8)
self.assertEqual(cnts.op_count, 10)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1,13 +1,16 @@
# Owner(s): ["module: dynamo"]
import functools
from typing import TYPE_CHECKING
import torch
import torch._inductor.test_case
import torch.fx.traceback as fx_traceback
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward
from torch._guards import detect_fake_mode
from torch._inductor.output_code import RegionalOutputCode
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.fx._graph_pickler import GraphPickler
@ -21,6 +24,10 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.triton_utils import requires_cuda_and_triton
if TYPE_CHECKING:
from torch._inductor.compile_fx import _CompileFxKwargs
# Open questions / follow-ups
# 1) CSE behavior with meta custom nodes
# Common subexpression elimination may not differentiate between distinct meta
@ -462,5 +469,154 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
self.assertEqual(len(codes), 2)
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class TestRegionalOutputCode(torch._inductor.test_case.TestCase):
"""Tests for RegionalOutputCode and BundledAOTAutogradResult."""
def test_regional_output_code_serialization(self):
"""Test that RegionalOutputCode can be serialized and deserialized."""
def fn(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
# Compile with regional inductor
with torch.fx.traceback.preserve_node_meta(enable=False):
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
fake_mode = FakeTensorMode()
with fake_mode:
fake_x = fake_mode.from_tensor(x)
fake_y = fake_mode.from_tensor(y)
gm = make_fx(fn)(fake_x, fake_y)
# Run regional_inductor on the graph
result_gm = regional_inductor(gm, fake_x, fake_y)
# Create RegionalOutputCode
output_code = RegionalOutputCode(result_gm)
# Test that we can call it
self.assertIsNotNone(output_code._graph_module)
# Serialize
output_code.prepare_for_serialization()
self.assertIsNone(output_code._graph_module)
self.assertIsNotNone(output_code._serialized_graph_module)
# Deserialize via post_compile
from torch._inductor.output_code import CompiledFxGraphConstants
fx_config: _CompileFxKwargs = {"is_backward": False}
output_code.post_compile(
[fake_x, fake_y], CompiledFxGraphConstants(), fx_config
)
self.assertIsNotNone(output_code._graph_module)
self.assertIsInstance(output_code._graph_module, torch.fx.GraphModule)
# Test that deserialized graph works
with fake_mode:
result = output_code([fake_x, fake_y])
self.assertIsNotNone(result)
def test_regional_output_code_with_backward(self):
"""Test RegionalOutputCode with both forward and backward compilation."""
def fn(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
# Compile with regional inductor backend
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
fake_mode = FakeTensorMode()
with fake_mode:
fake_x = fake_mode.from_tensor(x)
fake_y = fake_mode.from_tensor(y)
# Create forward graph
with torch.fx.traceback.preserve_node_meta(enable=False):
gm = make_fx(fn)(fake_x, fake_y)
forward_gm = regional_inductor(gm, fake_x, fake_y)
# Create forward output code
fw_code = RegionalOutputCode(forward_gm)
# Verify it can be called
with fake_mode:
result = fw_code([fake_x, fake_y])
self.assertIsNotNone(result)
# Test serialization round-trip
fw_code.prepare_for_serialization()
# Deserialize via post_compile
from torch._inductor.output_code import CompiledFxGraphConstants
fx_config: _CompileFxKwargs = {"is_backward": False}
fw_code.post_compile([fake_x, fake_y], CompiledFxGraphConstants(), fx_config)
with fake_mode:
result2 = fw_code([fake_x, fake_y])
self.assertIsNotNone(result2)
def test_regional_compiled_forward_backward(self):
"""Test BundledCompiledForward and BundledCompiledBackward with RegionalOutputCode."""
def fn(x):
with fx_traceback.annotate({"compile_with_inductor": 0}):
return torch.sin(x) * 2
x = torch.randn(5, requires_grad=True)
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
fake_mode = FakeTensorMode()
with fake_mode:
fake_x = fake_mode.from_tensor(x)
with torch.fx.traceback.preserve_node_meta(enable=False):
gm = make_fx(fn)(fake_x)
compiled_gm = regional_inductor(gm, fake_x)
# Create forward using the generic BundledCompiledForward
fw_code = RegionalOutputCode(compiled_gm)
fw_compiled = BundledCompiledForward[RegionalOutputCode](result=fw_code)
# Test pre_save
fw_compiled.pre_save()
# After pre_save, fw_compiled.result is a copy with serialized graph
self.assertIsNotNone(fw_compiled.result._serialized_graph_module)
self.assertIsNone(
fw_compiled.result._graph_module
) # Should be cleared after serialization
# Test load (doesn't deserialize yet)
loaded_code = fw_compiled.load([fake_x])
self.assertIsNone(loaded_code._graph_module) # Not yet deserialized
self.assertIsNotNone(loaded_code._serialized_graph_module)
fx_config: _CompileFxKwargs = {"is_backward": False}
post_compiled = fw_compiled.post_compile(loaded_code, fx_config)
self.assertIsNotNone(post_compiled)
self.assertIsNotNone(post_compiled._graph_module) # Now deserialized
if __name__ == "__main__":
run_tests()

View File

@ -697,7 +697,7 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
@torch._dynamo.config.patch(specialize_float=False, capture_scalar_outputs=True)
def test_unspecialized_float_multiply_precision(self):
dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64]
for i, dtype in enumerate(dtypes):
for dtype in dtypes:
def fn(x, y):
return x * y
@ -722,7 +722,7 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
return x + y.item()
dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64]
for i, dtype in enumerate(dtypes):
for dtype in dtypes:
x = torch.ones(3, 3, dtype=dtype)
self.assertEqual(f(x), x + x.sum().item())

View File

@ -600,6 +600,8 @@ def forward(self, x):
in_ptr1,
out_ptr,
n_elements,
fval,
ival,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
@ -608,7 +610,7 @@ def forward(self, x):
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
output = x + y + fval + ival
tl.store(out_ptr + offsets, output, mask=mask)
def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@ -618,7 +620,9 @@ def forward(self, x):
def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
wrap_triton(add_kernel)[grid](
x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16
)
return output
@ -633,7 +637,9 @@ def forward(self, x):
def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16, num_warps=8)
wrap_triton(add_kernel)[grid](
x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16, num_warps=8
)
return output
@ -661,34 +667,44 @@ def forward(self, x):
self.assertIsNotNone(triton_node)
args = []
kwargs = []
kwargs = {}
for arg in triton_node.inputs:
if arg.kind == ArgumentKind.POSITIONAL:
args.append(arg.arg)
elif arg.kind == ArgumentKind.KEYWORD:
kwargs.append(arg.arg)
kwargs[arg.name] = arg.arg
self.assertEqual(len(args), 4)
self.assertEqual(len(kwargs), 5)
self.assertEqual(len(args), 6)
# Always present: name, grid, output_indices, num_warps
# Triton version dependent: num_cpu_threads, shared_memory_bytes
self.assertTrue(len(kwargs) >= 4)
for i in range(3):
self.assertIsNotNone(args[i].as_tensor)
self.assertEqual(args[3].as_int, 3)
self.assertEqual(kwargs[0].as_string, "add_kernel") # name
self.assertEqual(kwargs[1].as_ints, [1, 1, 1]) # grid
self.assertEqual(kwargs[2].as_ints, [2]) # output indices
self.assertAlmostEqual(args[4].as_float, 3.14, places=2)
self.assertEqual(args[5].as_int, 42)
kernel_name = kwargs["name"].as_string
symbol_name = kernel_name.rpartition("_")[0]
self.assertEqual(symbol_name, "add_kernel")
self.assertEqual(kwargs["grid"].as_ints, [1, 1, 1])
self.assertEqual(kwargs["output_indices"].as_ints, [2])
self.assertEqual(
kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4
) # num warps
self.assertEqual(kwargs[4].as_int, 0) # shared mem bytes
kwargs["num_warps"].as_int, 8 if isinstance(m, MyModelAutotune) else 4
)
if "num_cpu_threads" in kwargs:
self.assertEqual(kwargs["num_cpu_threads"].as_int, 0)
if "shared_memory_bytes" in kwargs:
self.assertEqual(kwargs["shared_memory_bytes"].as_int, 0)
self.assertEqual(len(triton_node.outputs), 1)
self.assertIsNotNone(triton_node.outputs[0].as_tensors)
self.assertEqual(
len(triton_node.outputs[0].as_tensors), len(kwargs[2].as_ints)
len(triton_node.outputs[0].as_tensors),
len(kwargs["output_indices"].as_ints),
)
self.assertEqual(triton_node.outputs[0].as_tensors[0].name, "getitem")

View File

@ -675,7 +675,7 @@ class inner_f(torch.nn.Module):
# Verify buffer handling
buffer_count = 0
for desc, (node, grad_node) in input_grad_nodes.items():
for desc, (node, _grad_node) in input_grad_nodes.items():
if isinstance(desc, BufferAOTInput):
buffer_count += 1
self.assertIsNotNone(node)
@ -764,13 +764,13 @@ class inner_f(torch.nn.Module):
self.assertIn(node, named_params.values())
# Check that param_grads contains the same parameter nodes
for desc, (param_node, grad_node) in param_grads.items():
for desc, (param_node, _grad_node) in param_grads.items():
self.assertIn(param_node, param_nodes)
self.assertEqual(param_node, named_params[desc.target])
# Check that all_input_grads contains the parameter nodes
param_count = 0
for desc, (input_node, grad_node) in all_input_grads.items():
for desc, (input_node, _grad_node) in all_input_grads.items():
if isinstance(desc, ParamAOTInput):
param_count += 1
self.assertIn(input_node, param_nodes)

View File

@ -7555,7 +7555,7 @@ metadata incorrectly.
(_inp, _tg3),
]
for i, (inp_fn, tg_fn) in enumerate(TEST_CASES):
for inp_fn, tg_fn in TEST_CASES:
ref_x = inp_fn()
x = ref_x.detach().clone().requires_grad_()

View File

@ -3088,9 +3088,7 @@ class GraphModule(torch.nn.Module):
)
# Compare gradients for each layer
for i, (uncompiled_grad, compiled_grad) in enumerate(
zip(uncompiled_grads, compiled_grads)
):
for uncompiled_grad, compiled_grad in zip(uncompiled_grads, compiled_grads):
self.assertEqual(
uncompiled_grad,
compiled_grad,

View File

@ -282,7 +282,7 @@ class TestMin(TestCase):
# python 3.11 adapts bytecode after a number of iterations
# check that we still match names correctly
for i in range(10):
for _ in range(10):
f()
@skipIf(not TEST_CUDA, "no CUDA")

View File

@ -707,6 +707,50 @@ class TestConstFold(TestCase):
fold_result = mod_folded(in_x, in_y)
self.assertTrue(torch.equal(fold_result, base_result))
def test_fold_pure_subgraph(self):
class SubModule(torch.nn.Module):
def forward(self):
return torch.full((5, 10), 2.0) + 1
# Create a parent graph with this module as a subgraph and output
ep = torch.export.export(SubModule(), ())
parent_graph = torch.fx.Graph()
call_mod = parent_graph.call_module("sub", args=())
get_item = parent_graph.call_function(
operator.getitem, args=(call_mod, slice(None))
)
parent_graph.output((get_item,))
parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)
mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
parent, device_for_folded_attrs="cpu"
)
self._verify_const_fold_mod(mod_folded)
def test_do_not_fold_impure_subgraph(self):
"""
Skip folding any subgraph containing impure ops.
"""
class SubModule(torch.nn.Module):
def forward(self):
return torch.randn(5, 10) + 1
# Create a parent graph with this module as a subgraph and output
ep = torch.export.export(SubModule(), ())
parent_graph = torch.fx.Graph()
call_mod = parent_graph.call_module("sub", args=())
get_item = parent_graph.call_function(
operator.getitem, args=(call_mod, slice(None))
)
parent_graph.output((get_item,))
parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph)
mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
parent, device_for_folded_attrs="cpu"
)
self.assertIsNone(mod_folded.const_subgraph_module)
if __name__ == "__main__":
raise_on_run_directly("test/test_fx.py")

View File

@ -491,9 +491,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
def ins_dense():
return torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])
for i, (ins_fn, expected_fw_count) in enumerate(
zip([ins_sc, ins_dense], [2, 1])
):
for ins_fn, expected_fw_count in zip([ins_sc, ins_dense], [2, 1]):
reset_counter()
ref_out = fn(*ins_fn())
assert_counter(expected_fw_count, 0)
@ -524,16 +522,14 @@ def forward(self, arg0_1, arg1_1, arg2_1):
),
)
for i, (
for (
ins_fn_req_grad,
(
expected_fw_count,
expected_fw_count_after_bw,
expected_bw_count_after_bw,
),
) in enumerate(
zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)])
):
) in zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)]):
ref_ins = ins_fn_req_grad()
reset_counter()
ref_out = fn(*ref_ins)

Some files were not shown because too many files have changed in this diff.