Compare commits


316 Commits

Author SHA1 Message Date
b9caa336a0 Made everything unranked 2024-09-12 15:01:52 -07:00
dab7d646d5 Use a better decomposition for split_with_sizes (#135728)
This decomposition has fewer checks and improves the performance
of torch.compile.
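
For reference, a minimal illustration of the op this decomposition covers (shapes are illustrative):

```python
import torch

x = torch.arange(10)
# split_with_sizes returns views whose lengths are given explicitly
a, b, c = x.split_with_sizes([2, 3, 5])
print(a.shape, b.shape, c.shape)  # torch.Size([2]) torch.Size([3]) torch.Size([5])
```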
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135728
Approved by: https://github.com/ezyang
2024-09-12 16:38:51 +00:00
7647c398ff Allow optional positional arguments for torch.func.functional_call (#134643)
This PR resolves #134408. It adds an additional test, which passes locally.

Do you think we should add a post-check to ensure `args` and `kwargs` are not both `None`? It seems to be possible to have modules without inputs.

This PR does not include any such post-check.
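
A minimal sketch of the kwargs-only call pattern this enables (the module and input name are illustrative; it assumes `args` now defaults to `None`):

```python
import torch
import torch.nn as nn

lin = nn.Linear(3, 2)
params = dict(lin.named_parameters())
x = torch.randn(4, 3)

# positional args omitted; the module input is passed via kwargs only
out = torch.func.functional_call(lin, params, kwargs={"input": x})
print(out.shape)  # torch.Size([4, 2])
```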

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134643
Approved by: https://github.com/zou3519
2024-09-12 15:22:06 +00:00
d67cc58181 [ONNX] Fix symbolic values and numpy implementation (#135786)
1. Remove `__eq__` to make `SymbolicTensor` hashable and test for that
2. Update the `__array__` method so that it works for tensors on GPU

Fixes https://github.com/pytorch/pytorch/issues/135700
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135786
Approved by: https://github.com/titaiwangms
2024-09-12 14:24:43 +00:00
dddaadac6c [dynamo] Dont graph break on inner torch.compile (#135819)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135819
Approved by: https://github.com/jansel
2024-09-12 11:39:09 +00:00
02169364e1 [inductor] Split reduction loops when there is no shared reads (#134307)
Fixes #129102

![image](https://github.com/user-attachments/assets/0d00f75b-2bb9-4ce6-a0d9-2daceaff539c)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134307
Approved by: https://github.com/shunting314
2024-09-12 09:45:08 +00:00
c30042fbeb [GPT-fast] Update compilation time target for Llama & Mixtral (#135817)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135817
Approved by: https://github.com/xmfan, https://github.com/huydhn
2024-09-12 07:13:44 +00:00
6700175531 [Inductor] simplify indexing_exprs in LoopBody._init_with_copy (#135574)
This PR uses `var_ranges` information to simplify `indexing_exprs` in `LoopBody._init_with_copy` to reduce occurrences of `FloorDiv` and `ModularIndexing` in the `indexing_exprs`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135574
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-09-12 06:56:34 +00:00
de8a8653c0 [dtensor][BE] replace compute_local_shape with compute_local_shape_and_global_offset (#135554)
**Summary**
1. This PR removes the public API `compute_local_shape` and replaces its use with the more general API `compute_local_shape_and_global_offset`.
2. To keep `compute_local_shape_and_global_offset` consistent with `compute_local_shape` on empty shards, it now returns a local tensor shape of `(0,)` for empty shards, which is more aligned with DTensor's semantics on non-participating ranks.

**Test**
`pytest test/distributed/_tensor/test_dtensor.py`
`pytest test/distributed/_tensor/test_init.py`
`pytest test/distributed/_tensor/test_tensor_ops.py`

Differential Revision: [D62415591](https://our.internmc.facebook.com/intern/diff/D62415591)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135554
Approved by: https://github.com/tianyu-l, https://github.com/wz337
2024-09-12 06:30:09 +00:00
86335e9135 [reland 3/3][fx] Bypass custom __setattr__ in Node.__init__ (#135735)
Relands #135079 which was reverted by #135562

I broke this up into three parts to test internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135735
Approved by: https://github.com/oulgen
2024-09-12 05:50:39 +00:00
14e3f3c062 [aoti] Remove nlohmann/json.hpp from header (#135765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135765
Approved by: https://github.com/malfet
2024-09-12 05:38:51 +00:00
9852c6d236 xpu: fix 3rd party builds on systems with cmake<3.25 (#135767)
The CMake LINUX variable is only available starting from CMake 3.25. It is better to use CMAKE_SYSTEM_NAME instead to relax the CMake version requirement.

See: https://cmake.org/cmake/help/v3.25/variable/LINUX.html
Fixes: #135766
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135767
Approved by: https://github.com/malfet, https://github.com/guangyey
2024-09-12 05:31:01 +00:00
6354271178 [inductor] Skip unused call to get_estimated_runtime() (#135776)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135776
Approved by: https://github.com/oulgen
ghstack dependencies: #135445, #135446
2024-09-12 05:22:23 +00:00
12902f6ecf [inductor] Cache get_operation_names/get_buffer_names (#135446)
Before:
![image](https://github.com/user-attachments/assets/db5b6fce-d849-4512-a21d-7a09efc72311)

After:
![image](https://github.com/user-attachments/assets/097e340c-03b2-491e-ad36-132350b37892)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135446
Approved by: https://github.com/oulgen
ghstack dependencies: #135445
2024-09-12 05:22:23 +00:00
3decb676aa [inductor] Optimize cache_on_self (#135445)
This is a small compile time win, but also makes profiles more readable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135445
Approved by: https://github.com/oulgen
2024-09-12 05:22:23 +00:00
8d68a02905 OpenReg: Split the daemon into driver/executor (#135646)
Split the daemon into a proper user-process driver vs device-process executor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135646
Approved by: https://github.com/albanD
2024-09-12 05:03:46 +00:00
28330a8a39 [reland 1/3][fx] Bypass custom __setattr__ in Node.__init__ (#135733)
Relands #135079 which was reverted by #135562

I broke this up into three parts to test internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135733
Approved by: https://github.com/oulgen
2024-09-12 04:29:37 +00:00
eaba287adb [dynamo] Bug fix for _torchdynamo_inline source handling (#135612)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135612
Approved by: https://github.com/drisspg
2024-09-12 04:05:08 +00:00
cyy f5f1d0a753 Fix build warnings for torch_python (#134981)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134981
Approved by: https://github.com/ezyang
2024-09-12 03:59:34 +00:00
5bc238c73e torch.hub: add get_dir/set_dir type hints (#134906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134906
Approved by: https://github.com/Skylion007
2024-09-12 03:53:29 +00:00
79223114db Avoid inserting extra transpose when the input to group norm is NHWC (#135575)
When the input format for group norm is NHWC and the device is privateuseone, an additional transpose operation is introduced. To avoid this, a check for the privateuseone device needs to be added here.
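
For context, a minimal example of an NHWC (channels-last) input to group norm (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 32, 16, 16).to(memory_format=torch.channels_last)  # NHWC layout
y = F.group_norm(x, num_groups=4)
print(y.shape)  # torch.Size([8, 32, 16, 16])
```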

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135575
Approved by: https://github.com/ezyang
2024-09-12 03:36:05 +00:00
cyy 7cfd23636c Fix clang-tidy warnings in Caffe2 code (#134935)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134935
Approved by: https://github.com/ezyang
2024-09-12 03:27:09 +00:00
0d1d69fd25 Update torch-xpu-ops pin (ATen XPU implementation) (#135647)
Release cycle for PyTorch 2.5
1. Fixes a runtime error on Windows where torch_xpu_ops_unary_binary_kernels.dll fails to load because the binary size is too large.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135647
Approved by: https://github.com/EikanWang
2024-09-12 03:16:08 +00:00
21a64d57b1 [BE] typing for decorators - masked/_ops (#135108)
Differential Revision: D62184735

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135108
Approved by: https://github.com/Skylion007
2024-09-12 01:34:09 +00:00
1a74952925 "Remove BLOCK_LIST" (#135729)
Summary:
Skip test_prepare_qat_conv_bn_fusion_getitem_placeholder when we use training IR, since it only targets the bn-getitem pattern, which doesn't exist in training IR.

Remove BLOCK_LIST since it's empty.
Now all internal unittests will use training ir.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan'  caffe2/test/quantization:test_quantization -- -r test_prepare_qat_conv_bn_fusion_getitem_placeholder
buck2 run 'fbcode//mode/dev-nosan'  caffe2/test:quantization_pt2e_qat -- -r test_prepare_qat_conv_bn_fusion_getitem_placeholder
```

Differential Revision: D62387987

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135729
Approved by: https://github.com/tugsbayasgalan
2024-09-12 01:22:06 +00:00
a130ed828a Fix the upload of x86 micro benchmark results (#135780)
The upload stats workflow currently skips this https://github.com/pytorch/pytorch/actions/runs/10807251335/job/29977650639; this is a miss from https://github.com/pytorch/pytorch/pull/135042. So, the workflow is running but nothing has been uploaded yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135780
Approved by: https://github.com/atalman
2024-09-12 01:16:38 +00:00
eb0fe02933 [PT2][inductor][Optimus] Add pad_aten_mm_pass pattern to resolve long computation kernel in LCE (#135167)
Summary:
We observed another long-computation issue for the OBA_AFOC pyper model, so we add a pattern to avoid the perf regression.

- Only happens in A100
- We do not want to use force_shape_pad since it will pad all GEMMs, which may not be optimal. The Optimus pass has more flexibility to customize GEMM shapes and do the corresponding padding
- To enable it, we pass the pass to the config, where "k_threshold_to_pad" can be customized:

inductor_config.patch(post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad" : 8388608}})
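
A rough sketch of applying that config patch around compilation (the options dict is taken from the line above; the surrounding model and inputs are illustrative):

```python
import torch
import torch.nn as nn
import torch._inductor.config as inductor_config

model = nn.Linear(1024, 1024).cuda().half()
x = torch.randn(8, 1024, device="cuda", dtype=torch.half)

# enable the Optimus pattern via post_grad_fusion_options; threshold taken from the description above
with inductor_config.patch(
    post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
):
    out = torch.compile(model)(x)
```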

Test Plan:
# unit test

```
buck2 test mode/opt //caffe2/test/inductor:pad_mm
```
Buck UI: https://www.internalfb.com/buck2/58b0f272-f405-45be-bc8d-aec2dc4d5841
Test UI: https://www.internalfb.com/intern/testinfra/testrun/10133099209954651
Network: Up: 9.0KiB  Down: 142B  (reSessionID-8eb71a37-a5ca-4aff-a4f1-93ade3e47e4e)
Jobs completed: 9. Time elapsed: 3:18.0s.
Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3)
Tests finished: Pass 17. Fail 0. Fatal 0. Skip 0. Build failure 0

# e2e test
see [D62388582](https://www.internalfb.com/diff/D62388582)

Differential Revision: D62220158

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135167
Approved by: https://github.com/jackiexu1992
2024-09-12 00:51:34 +00:00
d270e2d240 [FSDP2] better error msg for cpu offloading (#135156)
When cpu offloading is enabled, if a user loads a gpu state dict, FSDP2 will throw a less obvious error at backward:
```
RuntimeError: attempting to assign a gradient with device type 'cpu' to a tensor with device type 'cuda'. Please ensure that the gradient and the tensor are on the same device
```

This PR throws the error more explicitly by specifying which parameters should be moved because of cpu offloading:

```
FSDP parameters should be materialized on cpu when enabling cpu offloading. For example, load cpu state dict or call module.to_empty(device="cpu"). Found following parameters on non-cpu device: ['0.weight']
```
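
A minimal sketch of the remedy the message points at, using a plain module for illustration (materialize on cpu, then load a cpu state dict):

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    m = nn.Linear(8, 8)                # parameters start on the meta device

cpu_state_dict = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}
m.to_empty(device="cpu")               # allocate real (uninitialized) cpu storage
m.load_state_dict(cpu_state_dict)      # parameters now live on cpu, as cpu offloading expects
print(m.weight.device)                 # cpu
```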

`pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135156
Approved by: https://github.com/awgu
2024-09-12 00:05:07 +00:00
16b37b309f [Inductor] Rename cpp_wrapper_cuda.py as cpp_wrapper_gpu.py (#135313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135313
Approved by: https://github.com/jansel, https://github.com/desertfire
ghstack dependencies: #135312
2024-09-11 23:59:54 +00:00
13ee85ca5e [Inductor] Generalize cuda cpp wrapper as common triton based GPU cpp wrapper, will be reused by xpu in next PR. (#135312)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135312
Approved by: https://github.com/jansel, https://github.com/desertfire, https://github.com/eellison
2024-09-11 23:59:54 +00:00
94d2471d1f [Traceable FSDP2] Use .copy_ instead of .set_ for unsharded_param inplace update; Replace unsharded_param graph input usage with graph intermediate; Support FSDP2+LoRA (#133730)
Using `fsdp.set_` for the unsharded_param inplace update causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_`, which fixes the error and also strictly follows eager semantics (i.e. if the user explicitly stores an alias of the unsharded_param during execution of the user's module code, that alias will get updated correctly when the unsharded_param is copy_'d into; whereas if we just swap out the unsharded_param storage via set_, that user-saved alias will not get updated, which is not good).
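
A toy illustration of that aliasing difference with plain tensors (not FSDP internals):

```python
import torch

p = torch.zeros(4)
alias = p.view(2, 2)             # user-saved alias sharing p's storage
p.copy_(torch.ones(4))           # in-place write: the alias observes the update
print(alias)                     # tensor([[1., 1.], [1., 1.]])

q = torch.zeros(4)
alias_q = q.view(2, 2)
q.set_(torch.full((4,), 2.0))    # storage swap: alias_q still points at the old storage
print(alias_q)                   # tensor([[0., 0.], [0., 0.]])
```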

This PR also implements the graph pass to remove the resizes and copy if there is a resize_(full) -> copy_ -> resize_(0) pattern.

------

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
2024-09-11 23:01:05 +00:00
5ca46be15e Fix/torch cat doc attr (#135698)
The `torch.cat` attr name for tensors in the docs differs from the method signature, unlike other methods.
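For reference, a call using the signature's parameter name `tensors`:

```python
import torch

a, b = torch.zeros(2, 3), torch.ones(2, 3)
out = torch.cat(tensors=(a, b), dim=0)  # first parameter is named `tensors` in the signature
print(out.shape)  # torch.Size([4, 3])
```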
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135698
Approved by: https://github.com/albanD

Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
2024-09-11 22:32:55 +00:00
9a04cfbeff fix for fp16 (#134106)
This PR is a replacement for https://github.com/pytorch/pytorch/pull/133085 for pushing a quick fix for RMSNorm.
The original author is @kkontny

Previous PR summary:
Since FP16 has a quite small dynamic range, it is very easy to overflow while computing `at::pow(input, 2)`, and it happens in real-world computation.

I've tried to use the `nn.RMSNorm` fused implementation instead of `LlamaRMSNorm` inside the `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started to give wrong answers in FP16 while still giving good ones in FP32. I figured out this happens due to overflow while computing the square of the input tensor.

The original `LlamaRMSNorm` implementation upcasts the input to FP32 to prevent this and give better numerical stability.

```
import torch
from torch import nn


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

The proposed commit fixes the issue. FP16 in RMSNorm has to be treated in a special way to be usable in real-world implementations.
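
A quick repro of the FP16 overflow in the squaring step (FP16 max is about 65504):

```python
import torch

x = torch.tensor([300.0], dtype=torch.float16)
print(x.pow(2))                 # tensor([inf], dtype=torch.float16): 300**2 = 90000 overflows fp16
print(x.float().pow(2).mean())  # upcasting to fp32 first keeps the value finite
```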

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134106
Approved by: https://github.com/mikaylagawarecki, https://github.com/eqy
2024-09-11 22:02:07 +00:00
66db61f0d1 [ONNX] Update fake mode usage in onnx docs (#135512)
Update fake mode usage in onnx docs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135512
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
2024-09-11 21:29:04 +00:00
c025f7becc Revert "[Partitioner] Reuse partition to check whether nodes exist (#135317)"
This reverts commit e004d539da3335d97a8134c9081245628f18eb67.

Reverted https://github.com/pytorch/pytorch/pull/135317 on behalf of https://github.com/izaitsevfb due to BC-breaking, breaks executorch and internal meta builds ([comment](https://github.com/pytorch/pytorch/pull/135317#issuecomment-2344730294))
2024-09-11 21:27:53 +00:00
8c4e1148b8 Refactoring byte_order (#135558)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135558
Approved by: https://github.com/mikaylagawarecki
2024-09-11 21:06:43 +00:00
e20ee39558 Expand bitwise ops to unsigned types (#135525)
Fixes https://github.com/pytorch/pytorch/issues/135436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135525
Approved by: https://github.com/ezyang
2024-09-11 20:48:52 +00:00
74fd1bf965 [ROCm] Update to AOTriton 0.7b (#134498)
Notable changes:
1. Enable CudaGraph related tests
2. Fix UT problems
3. EXPERIMENTAL Navi31 support. User should enable Navi31 support with Env Var `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`

Known problem:
1. `test/test_transformers.py` will show massive failures and/or NaN outputs with `--use-pytest`
    + Update: Confirmed that skipping `class TestSDPAPrivateUse1Only` fixes the problem with `--use-pytest`

Note:
AOTriton 0.7b adds support for nested tensors + SDPA but needs more work (and consequently a separate PR) to enable it.

Fixes #133540

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet
2024-09-11 20:34:01 +00:00
5d964a5eb7 [Export] Fix SDPA decomposition (#135297)
Summary: Update the SDPA decomposition to match the updated stride from D62009189, which aligns strides with `aten._scaled_dot_product_attention_math.default` and makes `t.permute().contiguous().permute()` no longer necessary.

Test Plan: CI

Differential Revision: D62278378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297
Approved by: https://github.com/drisspg
2024-09-11 20:21:59 +00:00
118d7e1480 [Inductor] add _dynamo.reset to test_cat_slice_cat_cuda (#135694)
Summary: test_cat_slice_cat_cuda runs inductor multiple times and checks counters["inductor"] in between, and thus we need to reset properly.

Differential Revision: D62500331

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135694
Approved by: https://github.com/masnesral
2024-09-11 20:07:11 +00:00
dd47f6f623 Simplify expr before getting implications in _maybe_evaluate_static (#135499)
Fixes #134268

Previously we weren't simplifying these expressions before calling get_implications, resulting in inconsistent application of FloorDiv/CleanDiv. See #134268  for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135499
Approved by: https://github.com/ezyang
2024-09-11 19:48:29 +00:00
e05ea2b179 Add decomposition for transpose_copy (#130943)
* Extracted from #128416
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130943
Approved by: https://github.com/amjames, https://github.com/eellison
2024-09-11 19:45:22 +00:00
ad75b09d89 Replace capture_pre_autograd_graph with export_for_training in torch tests (#135623)
Summary: as title

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_conv_dynamic
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:fx -- -r matcher
 buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r x86
```

CI

Differential Revision: D62448302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135623
Approved by: https://github.com/tugsbayasgalan
2024-09-11 19:23:08 +00:00
a2cb9b7331 Flip triton kernel default layout constraint to "needs_fixed_stride_order" (#135581)
This is to match the default layout constraint for custom operators. By
default, Inductor should match the stride order of inputs to a triton
kernel.

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135581
Approved by: https://github.com/eellison
ghstack dependencies: #135530
2024-09-11 18:43:18 +00:00
451eaf0ff2 Log full exception trace when error raised in Dynamo (#135697)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135697
Approved by: https://github.com/Skylion007
2024-09-11 18:14:33 +00:00
09519eb195 Support rolling over a percentage of workflows (#134816)
In order to support adding a rollover percentage, this ended up being a complete rewrite of runner_determinator.py.

Details of the new format are in the comments up top.

On the plus side, this now includes some unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134816
Approved by: https://github.com/PaliC, https://github.com/zxiiro
2024-09-11 18:01:26 +00:00
5314ae2660 Don't use exception chaining for BackendCompilerFailed (#135545)
Commandeered from https://github.com/pytorch/pytorch/pull/135496 as I'm now helping @ezyang ship dynamic float arguments in PT2.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135545
Approved by: https://github.com/ezyang
2024-09-11 17:49:18 +00:00
da587de9cb [ROCm] [BUGFIX] Re-enable rocm-specific tuning parameters v2 (#133852)
Small bug fix - https://github.com/pytorch/pytorch/pull/124592 replaced the torch.version.hip with device_props but made a mistake in porting the original logic.

The original code was:
`if torch.version.hip is not None:`

Which was incorrectly replaced by:
`if self.device_props.type != "hip":`

Another occurrence of https://github.com/pytorch/pytorch/pull/130617

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133852
Approved by: https://github.com/masnesral, https://github.com/malfet
2024-09-11 17:21:40 +00:00
82a4df2d5f [CI] [ROCm] Run rocm workflow on every push to main branch (#135644)
Dial the frequency back up from https://github.com/pytorch/pytorch/pull/131637

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135644
Approved by: https://github.com/huydhn
2024-09-11 17:21:05 +00:00
18a9030952 [CI] Fix update slow tests (#135390)
* Add pytorchbot to list of approvers for file
* Add labels to the auto created PR

The auto generated PR is currently not merging due to some failing tests on slow workflow that were supposed to be moved back to normal

idk if this has much value, clearly we've been managing without the update
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135390
Approved by: https://github.com/ZainRizvi
2024-09-11 17:02:17 +00:00
03f23d07b4 Optimize ShapeEnv.replace (#135652)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135652
Approved by: https://github.com/ezyang
ghstack dependencies: #135621, #135622
2024-09-11 16:50:59 +00:00
8c738c9270 Improve performance of sympy_generic_le (#135622)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135622
Approved by: https://github.com/ezyang
ghstack dependencies: #135621
2024-09-11 16:20:03 +00:00
7ddacaf40a Improve performance of canonicalize_bool_expr (#135621)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135621
Approved by: https://github.com/ezyang
2024-09-11 16:20:03 +00:00
183c32fd3b Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)"
This reverts commit 0d15122092c27fec1143b800bab7c996d126b547.

Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/clee2000 due to something in this stack broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](444b52ff40), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/133137#issuecomment-2344054339))
2024-09-11 15:57:00 +00:00
3ab12e2596 Revert "[Dynamo] Support thread local setattr (#135443)"
This reverts commit 160c228a4bd60ceffa62b045a6b0a6f9413835c5.

Reverted https://github.com/pytorch/pytorch/pull/135443 on behalf of https://github.com/clee2000 due to something in this stack broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](444b52ff40), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/135443#issuecomment-2344042800))
2024-09-11 15:53:55 +00:00
596e93b506 Revert "[dynamo] Bug fix for _torchdynamo_inline source handling (#135612)"
This reverts commit 5c3d0a2dedbc0e85f3b256ce56ac674078a5fae1.

Reverted https://github.com/pytorch/pytorch/pull/135612 on behalf of https://github.com/clee2000 due to broke inductor/test_cpu_select_algorithm.py::TestSelectAlgorithmCPU::test_linear_input_transpose_bias_True_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/10805518363/job/29982386304) [HUD commit link](5c3d0a2ded), bad TD ([comment](https://github.com/pytorch/pytorch/pull/135612#issuecomment-2344039370))
2024-09-11 15:51:12 +00:00
f96e8041b1 Revert "[Dynamo] Simplify torch function mode stack guard (#135444)"
This reverts commit 444b52ff40cf4afce7bc3fdcf021a88eab3b954c.

Reverted https://github.com/pytorch/pytorch/pull/135444 on behalf of https://github.com/clee2000 due to something in this stack broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](444b52ff40), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/135444#issuecomment-2344036843))
2024-09-11 15:48:27 +00:00
7cf9c81918 Revert "[Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732)"
This reverts commit 6a3edfcc1e474e6ebd0c06624000a6d6bf1a0dee.

Reverted https://github.com/pytorch/pytorch/pull/134732 on behalf of https://github.com/clee2000 due to broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](444b52ff40), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2344016694))
2024-09-11 15:39:21 +00:00
49e0b88aab Fix test_triton_kernel_float64_constant (#135583)
Summary: Landed https://github.com/pytorch/pytorch/pull/135260 too soon, and the test in that PR doesn't do exactly what I tested (actually testing different dtypes).

Test Plan: `python test/inductor/test_triton_kernels.py -k float64_constant`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135583
Approved by: https://github.com/isuruf, https://github.com/eellison, https://github.com/Skylion007
2024-09-11 15:16:23 +00:00
ee8c5cc1cc For S444023: Back out "deprecate search_autotune_cache (#133628)" (#135186)
Summary: For S444023

Test Plan:
Revert prevented the NaN errors - f639391901
Training job ran for 7767 iterations. NaN errors show up within the first 1k.

Reviewed By: nmacchioni

Differential Revision: D62224747

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135186
Approved by: https://github.com/kit1980
2024-09-11 14:08:40 +00:00
ce4d146f56 ATen | Fix MPSCNNNeuron creation on Mac Catalyst. (#135595)
Summary:
These are still utilized directly when using relu/sigmoid/tanh tensors from here: https://fburl.com/code/k6n7ofzd
However, on Mac Catalyst we were always returning `nil`, which in most cases rendered the entire graph useless and most often left stray `MPSTemporaryImage` references that were never written into.

This fixes the issue completely by making sure that we always return valid kernels, so they can be executed.

Test Plan: Test with segmentation net that uses a combination of relu and other tensors together - run this via Mac Catalyst build - it works! {F1858576745}

Reviewed By: MichaelTay

Differential Revision: D62430010

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135595
Approved by: https://github.com/MichaelTay
2024-09-11 11:12:23 +00:00
0226fcaacf Disable cuda specific restrictions in _scaled_mm for other devices (#135579)
Fixes #135576

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135579
Approved by: https://github.com/drisspg
2024-09-11 11:05:38 +00:00
4cde5096c4 [Inductor][FlexAttention] Supports dynamic shapes with block mask (#135629)
Fixes #134560 and #135206

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135629
Approved by: https://github.com/drisspg
2024-09-11 08:10:50 +00:00
443c015393 [Distributed] Improve efficiency of NaN checker (#135414)
Some customers would like to run the NaN checks on the fly, so we are improving its efficiency.

## Benchmarking
Allreduce 2G floats. `TORCH_NCCL_NAN_CHECK=1`
Red kernel: ncclAllreduce
Blue kernel: Nan check

<img width="1093" alt="Screenshot 2024-09-06 at 10 00 05 PM" src="https://github.com/user-attachments/assets/5501bc31-024f-4115-adb2-dd66eb4025d3">

## Comparison with torch ops:
Let's say a user manually checks for NaNs with the following torch ops before all-reduce:
```
torch.any(torch.isnan(x))
```
<img width="1091" alt="Screenshot 2024-09-06 at 10 14 53 PM" src="https://github.com/user-attachments/assets/1f8b5f63-c955-4612-bb96-241b6c69959b">

So our perf is on-par with torch ops.

## Changes
- Load from vidmem using "big packs" of 16 bytes
- Bump `blockDim.x` from 256 to 512
- Separate loads and checks into two loops, each of 8 iterations
- Unroll the loops
- Templated functions for checking NaN in a "big pack" based on dtype

Special thanks to @jbachan from NCCL!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135414
Approved by: https://github.com/wconstab
2024-09-11 07:53:42 +00:00
4ae6d7c18f Back out "[pytorch][PR] [export] fix re-export custom metadata" (#135634)
Summary: Broke some tests. Revert this diff

Test Plan: CI

Differential Revision: D62474337

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135634
Approved by: https://github.com/tugsbayasgalan
2024-09-11 06:16:26 +00:00
3084b7b5c0 [cuDNN][SDPA] Support attn_bias in cuDNN (#130482)
CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130482
Approved by: https://github.com/drisspg, https://github.com/Skylion007, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-09-11 05:59:25 +00:00
5c3d0a2ded [dynamo] Bug fix for _torchdynamo_inline source handling (#135612)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135612
Approved by: https://github.com/drisspg
ghstack dependencies: #135588
2024-09-11 05:23:42 +00:00
c608b17f60 [PTD][BE][c10d] Add some code documents for TCPStore code and cosmetic changes to libUVStore code (#130496)
While designing something else that needs TCPStore, I spent some time digging into the TCPStore codebase and found that the code is a little bit challenging to understand without proper documentation. Although people from the OSS community must be smarter than me, I still want to document my findings in the code so that devs and users can use them as a reference down the road.

Also, for libuv, we need to prefix private variables with a "_", so this is a pure renaming of private variables such as `tcpServer`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130496
Approved by: https://github.com/wconstab
2024-09-11 04:42:25 +00:00
444b52ff40 [Dynamo] Simplify torch function mode stack guard (#135444)
The semantics of ignored modes previously had edge cases; this eliminates them by, in essence, filtering any ignored modes out of both the ref stack and the current torch function mode stack. This is purely to fix complexity in #135422. The ignored-modes handling will be removed in a future PR after https://github.com/pytorch/pytorch/pull/135422 lands, since we will then trace through DeviceContexts vs inserting them into the graph, which needed these extra workarounds for correctness.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135444
Approved by: https://github.com/anijain2305, https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443
2024-09-11 04:18:22 +00:00
160c228a4b [Dynamo] Support thread local setattr (#135443)
In preparation for tracing through DeviceContext (defb515306/torch/utils/_device.py (L66)), this PR adds support for calling the setattr of thread-local objects. These objects have a slots impl, and since this doesn't appear to have any side effects, we call this setattr impl when replaying mutations, since calling `object.__setattr__` on these objects results in a type error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135443
Approved by: https://github.com/anijain2305
ghstack dependencies: #134732, #133137
2024-09-11 04:18:22 +00:00
0d15122092 [Dynamo] Trace torch function modes entered outside of torch.compile (#133137)
This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call.
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.

Previously landed:
https://github.com/pytorch/pytorch/pull/133135
https://github.com/pytorch/pytorch/pull/133136
https://github.com/pytorch/pytorch/pull/133134
https://github.com/pytorch/pytorch/pull/133133
https://github.com/pytorch/pytorch/pull/133132
https://github.com/pytorch/pytorch/pull/133131
https://github.com/pytorch/pytorch/pull/133729
https://github.com/pytorch/pytorch/pull/133130

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #134732
2024-09-11 04:18:22 +00:00
6a3edfcc1e [Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732)
For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732
Approved by: https://github.com/ydwu4
2024-09-11 04:18:22 +00:00
356f14e7b7 Fix the output of FileCheck when not run and add unit tests (#135345)
When FileCheck is destructed without execution, it should output all rules.
For example:
```
>>> fc = FileCheck().check("test")
>>> del fc
You have not run this instance of FileCheck!
FileCheck checks:
        CHECK: test
```

Additionally, unit tests for the Python interface of FileCheck will be added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135345
Approved by: https://github.com/eellison
2024-09-11 04:13:24 +00:00
34dc8f69a1 Adding entry-point based support for out-of-tree rendezvous plugins (#132633)
Fixes #127519

Currently in torchrun rendezvous, there are only two rendezvous backends supported out of the box: `C10d` and `Etcd`. The changes in this PR enable distributed elastic users to bring their out-of-tree rendezvous backend implementations as Python packages.

#### AUTHORING NEW PLUGIN
Any new plugin will be a python package exposing entry-points. For example, the structure of redis plugin is as follows:

```
plugin_root
|_ pyproject.toml
|_ src
   |_ redis
      |_ __init__.py
      |_ redis_store.py
      |_ redis_backend.py
```

The contents of the `pyproject.toml` should indicate that this package exposes a torchrun entry-point by mentioning the group name `torchrun.plugins`. The `pyproject.toml` for the redis plugin would be as follows:

```
[project]
name = "redis"
version = "0.0.1"

[project.entry-points.'torchrun.plugins']
redis = 'redis'
```

The `src/redis/__init__.py` file would contain functions that return the plugin name and plugin handler. The contents of `__init__.py` for redis would be as follows:

```
def getPluginHandler():
    def _create_redis_handler(params: RendezvousParameters):
        from redis_rendezvous_backend import create_backend
        backend, store = create_backend(params)
        return create_handler(store, backend, params)
    return _create_redis_handler
```

The files `redis_store` and `redis_backend` contain the implementation of [Store](41189b0da4/torch/_C/_distributed_c10d.pyi (L171)) and [RendezvousBackend](e782918b8e/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py (L61)) respectively.

#### USER EXPERIENCE
Before using the plugin for the first time, the user has to install the plugin packages. For example, published packages can be installed using `pip3 install <plugin-name>`, and a plugin in the local file system can be installed using `pip3 install -e <plugin-location>`.

Once installed, the new backend can be used in torchrun as follows:

```
torchrun --rdzv-backend=redis --rdzv-endpoint=redis-container:6379 --nnodes=3 --nproc-per-node=1 --max-restarts=3 --rdzv-id=1 test.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132633
Approved by: https://github.com/fduwjj
2024-09-11 03:35:02 +00:00
cd9ee49a69 [aoti] Add cpp loader (#135374)
* Added a cpp loader, AOTIModelPackageLoader, which can load the .pt2, build the .so, and create a runner. The python-facing API is that users can directly call the `run` function, whereas in cpp users can directly access the `runner_` if they are more familiar with that. I couldn't figure out how to bind the `get_runner()` function to python...
* Added a new config, `aot_inductor.package_cpp_only` which will **not** package the so. This means that whenever the package is loaded, we will need to build the so. This is turned off by default so that new environments do not need to rebuild their so. The `package_cpp_only` is a feature which torchchat intends to use to provide flexibility to users.
* Added a new config, `aot_inductor.metadata` which stores user-provided metadata, serialized to the pt2 as a json file. It also stores the device used when exporting, "cuda" or "cpu", so that during load time, we can use that data to determine which AOTIModelContainerRunner to use. The metadata can be accessed through `loader.get_metadata()`. TODO is to move this metadata to the toplevel `package_aoti` function so that we can remove the metadata as a config.
* Separated out `package_aoti` as a standalone function, instead of it automatically being called in inductor. This is to prepare for the case where users will compile multiple models and want to bundle them in one package. The specific use case is in torchchat, where we want to package the separately-exported encoder and decoder layers. An example of how to use this is in `test_multiple_methods`.
* `load_package` will load a singular model, given the model name.
* The loader doesn't support Windows for now; I think I need to add some more casing to make the build commands work on Windows.

Differential Revision: [D62329906](https://our.internmc.facebook.com/intern/diff/D62329906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135374
Approved by: https://github.com/desertfire, https://github.com/malfet
2024-09-11 03:00:01 +00:00
26e5572dd2 Bump triton xpu pin and release version (#135638)
Similar with https://github.com/pytorch/pytorch/pull/135627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135638
Approved by: https://github.com/atalman
2024-09-11 00:56:15 +00:00
693897df42 [dynamo] Missing guard source keys for corner case of NNModuleVariabl… (#135041)
Potentially fixes - https://fb.workplace.com/groups/1286739428954016/permalink/1319662695661689/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135041
Approved by: https://github.com/ezyang
2024-09-11 00:43:26 +00:00
3bf6be457d [MPS] Add missing dispatch to rshift.Tensor (#135607)
Missed it while working on https://github.com/pytorch/pytorch/pull/131813
Test plan: `python -c "import torch;print(torch.randint(100, 500, (64,), device='mps') >> torch.tensor([3,], device='mps'))"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135607
Approved by: https://github.com/manuelcandales
2024-09-11 00:20:53 +00:00
492f064f15 [ONNX] Add assertion nodes to ignoring list (#135591)
Fixes #135419

PS: there are 104 empty output nodes, I suggest we add them one by one when we run into them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135591
Approved by: https://github.com/justinchuby
2024-09-11 00:18:17 +00:00
29408ea81a Add option to tweak inductor stride settings for user-defined triton kernels (#135530)
Previously, Inductor was allowed to modify the stride/storage_offset
(layout) for inputs to user-defined triton kernels. This can cause
silent incorrectness because most triton kernels are written for a
specific striding pattern (usually contiguous).

This PR adds a config to allow the user to choose Inductor's behavior on
this. The options are:
- "flexible_layout" (default): Inductor can modify the layout for inputs
  to user-defined triton kernels as much as it wants.
- "needs_fixed_stride_order": Inductor must preserve the stride order
  (when compared to tracing) for inputs to user-defined triton kernels.

This matches our handling for custom operators. In the future, we'll
want a "needs_exact_strides" option (this is the safest option).

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135530
Approved by: https://github.com/FindHao, https://github.com/oulgen
2024-09-11 00:11:17 +00:00
02dcb07765 Add boolean support in pack segments ops for both cpu and cuda impls (#132897) (#135620)
Summary:

Same as int types, forward only.

bypass-github-export-checks diff has been synced to github

Test Plan:
buck test mode/dev-nosan //caffe2/torch/fb/sparsenn:test -- test_pack_segments
https://www.internalfb.com/intern/testinfra/testconsole/testrun/16888498646804437/

Reviewed By: garroud

Differential Revision: D60785563

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135620
Approved by: https://github.com/kit1980

Co-authored-by: Haoming Lu <haominglu@meta.com>
2024-09-11 00:03:17 +00:00
5c38aa72c0 [dynamo][dicts][nv-embed] Support update with kwargs (#135588)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135588
Approved by: https://github.com/yanboliang
2024-09-10 23:50:23 +00:00
5134ba7458 Bump triton pin and release version (#135627)
Update the pin and release version to sync with https://github.com/triton-lang/triton/tree/release/3.1.x

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135627
Approved by: https://github.com/Chillee, https://github.com/drisspg, https://github.com/malfet
2024-09-10 23:46:36 +00:00
e48ee2cf50 [ONNX] Fix scaled_dot_product_attention with float scale (#135594)
Fixes #125158

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135594
Approved by: https://github.com/justinchuby
2024-09-10 23:04:02 +00:00
eb38ee21ba [ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (#135397)
Fixes #132964

This change optimizes torch.sum() performance by increasing max_values_per_thread in setReduceConfig() for the ROCm platform.
By increasing this parameter, the kernel uses fewer thread blocks, which improves performance.

Test:
Tested on MI300x and H100, and now the MI300x perf improved to 3205GByte/s from ~1690GByte/s for the test case and is slightly better than H100 (3136GByte/s).

Also tested with other tensor sizes and likewise saw perf improvements.

```python
import torch
from triton.testing import do_bench

x = torch.randn(2**30, device='cuda')

ms = do_bench(lambda: x.sum(dim=-1))

bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9)

time_s = ms / 1000

bw_per_second = bandwidth_gbyte / time_s

print(bw_per_second)
```

Co-author: @carlobertolli

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135397
Approved by: https://github.com/eqy, https://github.com/malfet
2024-09-10 21:03:01 +00:00
8057b72763 [ez][inductor] don't benchmark cloning if there are no mutated args (#135533)
When a kernel does not have mutated args (this is quite common?), benchmarking the cost of cloning actually benchmarks a no-op. This still takes >100ms since triton.testing.do_bench will allocate 100 ms budget to run the kernel.
Skipping this benchmarking can save quite some compilation time if the code path is hit multiple times. Let's say, if the code path is hit 100 times when the graph is large, we would save >10s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135533
Approved by: https://github.com/jansel
ghstack dependencies: #135531
2024-09-10 20:54:31 +00:00
7b17918dc9 [inductor] fix a device sync issue for benchmarking fusion (#135531)
Fix https://github.com/pytorch/pytorch/issues/134768 .

When we benchmark the latency for a fused node set, we do benchmarking twice:
1. benchmark the latency of the kernel including cloning mutated args
2. benchmark the latency of cloning mutated args without running the kernel

We subtract result 2 from result 1 to get the latency of the kernel itself.

But when the tensors are not on cuda device 0, we get equal numbers for result 1 and result 2 no matter how much work the kernel does. The root cause is that in `triton.testing.do_bench` the `torch.cuda.synchronize` call syncs the current cuda device (which is device 0 if it's not overridden). But since the tensors and kernels are located on another device, the sync actually does nothing (unless there happen to be other kernels on device 0).

The fix is to set the correct current device in our benchmarking code.
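
A minimal sketch of the fix (device index and workload are illustrative):

```python
import torch
from triton.testing import do_bench

x = torch.randn(4096, 4096, device="cuda:1")  # tensors live on a non-default device

# do_bench's torch.cuda.synchronize() syncs the *current* device (cuda:0 by default),
# so set the current device to where the tensors/kernels actually live before benchmarking
with torch.cuda.device(x.device):
    ms = do_bench(lambda: x @ x)
print(ms)
```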

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135531
Approved by: https://github.com/jansel
2024-09-10 20:54:31 +00:00
66c45f3ed9 [export] fix re-export custom metadata (#135282)
Fixes #134778

When a model is exported and debug handles are added to the "custom" field of non-placeholder and non-output nodes in the graph, re-exporting it will change the metadata of placeholder nodes (the "custom" field will be added or copied to these nodes, depending on whether `ExportedProgram` or `ExportedProgram.module()` is passed to `generate_numeric_debug_handle()`).

This occurs because when we re-export the model, `placeholder` nodes are unlifted to `get_attr` nodes. These nodes remain as `get_attr` after being exported to `gm_torch_level`.  Their metadata are modified [here](https://github.com/pytorch/pytorch/blob/main/torch/export/_trace.py#L1347) based on `params_buffers_to_node_meta` which is collected [here](https://github.com/pytorch/pytorch/blob/main/torch/export/_trace.py#L1312).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135282
Approved by: https://github.com/jerryzh168, https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2024-09-10 20:15:02 +00:00
0a9d55d2ee Revert "[AOTI] Fix assert_function call in cpu autotune template (#135086)"
This reverts commit 16c3b8f87cfa9cb5acee8104820baa389e7ee2bd.

Reverted https://github.com/pytorch/pytorch/pull/135086 on behalf of https://github.com/izaitsevfb due to breaks internal tests, see D62405818 ([comment](https://github.com/pytorch/pytorch/pull/135086#issuecomment-2341889428))
2024-09-10 19:51:16 +00:00
4ca65d3323 [CI] Increase sharding for jobs that are timing out (#135582)
Increase sharding for
* slow grad check
* slow cuda tests slow / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test
* avx

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135582
Approved by: https://github.com/huydhn, https://github.com/malfet
2024-09-10 19:45:13 +00:00
c932b39739 [FSDP2] Added _set_unshard_async_op (#135523)
This PR adds a private API `_set_unshard_async_op` that allows for running pre-forward and pre-backward all-gathers using the `async_op=True` path so that all-gather allocations happen in the default stream to avoid inter-stream fragmentation.

If using this option, forward requires explicit prefetching e.g. via the `unshard(async_op=True)` API for overlap. fp32 -> bf16 casts and the all-gather copy-in will not overlap with compute.

Differential Revision: [D62401551](https://our.internmc.facebook.com/intern/diff/D62401551)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135523
Approved by: https://github.com/weifengpy
2024-09-10 19:28:02 +00:00
1f15973657 [AOTI][Tooling][7/n] Add debug printing support for JIT inductor codegen path as well (#135285)
Summary:
1. Add the debug printer call one level lower for the triton kernel python wrapper codegen path
2. Add `torch.save()` for jit inductor as well
3. This also fixes the issue introduced in D61949020 (at the python wrapper code level, triton kernels were not printed)

Test Plan:
```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1  TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_abi_compatible_cuda
```

Differential Revision: D62272588

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135285
Approved by: https://github.com/chenyang78
2024-09-10 19:24:58 +00:00
fc88ba260f [amdsmi][torch] Update amdsmi API usages (#135504)
Summary: In ROCm 6.2.0 there were API name changes; we check if the new APIs exist and use them in this diff. See 7b2463abe0 for the changes.

Test Plan: CI

Differential Revision: D62325661

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135504
Approved by: https://github.com/eqy, https://github.com/houseroad
2024-09-10 19:15:39 +00:00
bf8d0e3107 [inductor] Enable subprocess parallel compile internally with killswitch (#132467)
Differential Revision: [D60629630](https://our.internmc.facebook.com/intern/diff/D60629630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132467
Approved by: https://github.com/eellison
2024-09-10 19:05:46 +00:00
3a1239a248 [Profiler] Harden Record Function Kwargs (#135365)
Summary:
In S445839, we had HTA break because of the "stream" parameter that was added to gpu traces. This brought up discussions regarding hardening our post-processing of said inputs so as to not break the JSON schema or downstream tools. For this reason, this diff does the following:

1. Only allow int, double, bool and string values to be processed as kwinputs for JSON output. We can handle lists if needed in the future.
2. Make sure that any boolean is lowercased when rendered as a string so that the JSON does not break when parsing it
3. Force the stream parameter to be an int

Test Plan: Added unit tests to ensure that the list of requirements above is true for kwargs only.

Differential Revision: D62304843

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135365
Approved by: https://github.com/aaronenyeshi
2024-09-10 18:44:05 +00:00
4f9f1775d8 Fix flaky TestCudaWrapper.test_randint_cuda_cuda_wrapper (#135370)
Summary: This test is flaky when run after `test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda_cuda_wrapper` because the TestCase sets config options globally in its setUp() that stick around for subsequent tests. For test isolation, we use a contextlib.ExitStack pattern in other tests to patch the config options and restore them in tearDown(). Update all TestCases in `test/inductor/test_combo_kernels.py` to use that pattern.
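
A sketch of the isolation pattern referred to here (the base class import and config name are illustrative assumptions):

```python
import contextlib
import torch._inductor.config as inductor_config
from torch._inductor.test_case import TestCase  # assumed base class for inductor tests

class ComboKernelTests(TestCase):
    def setUp(self):
        super().setUp()
        # patch config for this test only; restored in tearDown
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(inductor_config.patch({"combo_kernels": True}))

    def tearDown(self):
        self._stack.close()
        super().tearDown()
```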

Test Plan:
```
python test/inductor/test_combo_kernels.py
python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda_cuda_wrapper TestCudaWrapper.test_randint_cuda_cuda_wrapper
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135370
Approved by: https://github.com/jansel
2024-09-10 18:43:14 +00:00
5e0788befb Migrate remaining jobs to use runner determinator (#134867)
At this point all self-hosted runner jobs should be using the runner determinator to switch between LF and Meta runners. This change updates the remaining jobs that have not yet been migrated over.

Issue: https://lf-pytorch.atlassian.net/browse/PC-25

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134867
Approved by: https://github.com/ZainRizvi
2024-09-10 18:14:00 +00:00
440f8f57af Revert "[fx] Bypass custom __setattr__ in Node.__init__ (#135079)" (#135562)
This reverts commit 66da3b3b2acacb116a9b23e91b24934830eaf6b8.

#135079 breaks internal tests and needs to be reverted. Reverting with mergebot doesn't work as this PR is technically part of the stack, but, according to @jansel, it should be possible to revert it individually.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135562
Approved by: https://github.com/jansel, https://github.com/seemethere
2024-09-10 18:07:11 +00:00
e004d539da [Partitioner] Reuse partition to check whether nodes exist (#135317)
Checking whether a node is in a NodeList has O(n) time complexity. Reuse the partition to speed this up, since partition.nodes is a hash table containing the same elements.
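
The underlying point in plain Python: membership tests on a list are a linear scan, while a set or dict gives average O(1) lookups.

```python
import timeit

nodes = list(range(100_000))
node_set = set(nodes)

print(timeit.timeit(lambda: 99_999 in nodes, number=1_000))     # O(n) scan per lookup
print(timeit.timeit(lambda: 99_999 in node_set, number=1_000))  # ~O(1) hash lookup
```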

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135317
Approved by: https://github.com/ezyang
2024-09-10 17:45:29 +00:00
c4b84a46a9 Add more logging to TunableOp validators (#135396)
Summary: Add more logging to TunableOp validators

Test Plan:
Verified additional logging when loading kernel selections:
```
ROCBLAS_VERSION validation: expect 4.0.0-72e57364-dirty to match 4.0.0-72e57364-dirty
GCN_ARCH_NAME validation: expect gfx942:sramecc+:xnack- to match gfx942:sramecc+:xnack-
HIPBLASLT_VERSION validation: expect 800-a15e4178 to match 800-a15e4178
ROCM_VERSION validation: expect 6.0.0.0-12969-1544e39 to match 6.0.0.0-12969-1544e39
PT_VERSION validation: expect 2.5.0 to match 2.5.0
```

```
[qizixi@devgpu039.atn3 /data/users/qizixi/fbsource/fbcode (f9305317d|remote/master)]$ PYTORCH_TUNABLEOP_VERBOSE=1 buck2 run mode/{opt,amd-gpu} -c fbcode.enable_gpu_sections=true //scripts/xdwang/example:fc_llama -- --enable-tuning
File changed: fbcode//hipblas_tuning_pt_llama0.csv
Buck UI: https://www.internalfb.com/buck2/1ed2fac4-743e-49ef-805f-7fb6b9300022
Network: Up: 0B  Down: 0B
Jobs completed: 4189. Time elapsed: 0.2s.
BUILD SUCCEEDED
Enabled tuning
- Run Linear (matmul) 2 x 1280 x 8192, dtype = torch.bfloat16
INFO:2024-09-06 14:38:07 2834864:2835138 CuptiActivityProfiler.cpp:260] HIP versions. Roctracer: 4.1; Runtime: 60032830; Driver: 60032830
INFO:2024-09-06 14:38:07 2834864:2836083 DynoConfigLoader.cpp:61] Setting communication fabric enabled = 0
reading tuning results from hipblas_tuning_pt_llama0.csv
Validator PT_VERSION=2.5.0
Validator ROCM_VERSION=6.0.0.0-12969-1544e39
Validator HIPBLASLT_VERSION=800-a15e4178
Validator GCN_ARCH_NAME=gfx942:sramecc+:xnack-
Validator ROCBLAS_VERSION=4.0.0-72e57364-dirty
ROCBLAS_VERSION validation: expect 4.0.0-72e57364-dirty to match 4.0.0-72e57364-dirty
GCN_ARCH_NAME validation: expect gfx942:sramecc+:xnack- to match gfx942:sramecc+:xnack-
HIPBLASLT_VERSION validation: expect 800-a15e4178 to match 800-a15e4178
ROCM_VERSION validation: expect 6.0.0.0-12969-1544e39 to match 6.0.0.0-12969-1544e39
PT_VERSION validation: expect 2.5.0 to match 2.5.0
Loading results
Avg time: 13.165860176086426 us, Achieved 3.19 TFLOPS, 1598.24 GB/s

- Run Linear (matmul) 2 x 8192 x 1024, dtype = torch.bfloat16
Avg time: 13.230760097503662 us, Achieved 2.54 TFLOPS, 1271.14 GB/s

- Run Linear (matmul) 2 x 7168 x 8192, dtype = torch.bfloat16
Avg time: 26.804399490356445 us, Achieved 8.76 TFLOPS, 4384.90 GB/s

- Run Linear (matmul) 2 x 8192 x 3584, dtype = torch.bfloat16
Avg time: 13.407809734344482 us, Achieved 8.76 TFLOPS, 4384.14 GB/s

2x1280x8192-torch.bfloat16,13.165860176086426,3.18574247630113,1598.237845349412
2x8192x1024-torch.bfloat16,13.230760097503662,2.536092541374924,1271.1420867780075
2x7168x8192-torch.bfloat16,26.804399490356445,8.762778814892096,4384.9040543618985
2x8192x3584-torch.bfloat16,13.407809734344482,8.759112362638383,4384.138585247748
```
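
A hedged sketch of how the validator logging above can be reproduced from Python; the environment variables are the ones visible in the log, and the matrix shapes are only illustrative:

```python
# Hedged sketch: set the TunableOp environment variables (shown in the log above)
# before the first GEMM runs, then execute a linear layer that TunableOp can tune.
import os

os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"   # turn tuning on
os.environ["PYTORCH_TUNABLEOP_VERBOSE"] = "1"   # print validator lines like those above

import torch

if torch.cuda.is_available():
    x = torch.randn(2, 8192, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(1280, 8192, device="cuda", dtype=torch.bfloat16)
    y = torch.nn.functional.linear(x, w)        # a 2 x 1280 x 8192 GEMM, as in the log
```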

Reviewed By: leitian

Differential Revision: D62322830

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135396
Approved by: https://github.com/eqy
2024-09-10 17:20:59 +00:00
bc1b8f094d Check function declarations of Core ML code (#135467)
Relax the restrictions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135467
Approved by: https://github.com/ezyang
2024-09-10 16:05:22 +00:00
f65a564fa2 [inductor] Flip custom_op_default_layout_constraint (#135239)
By default, Inductor should respect the stride order of input Tensors to
custom operators.
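
A hedged sketch of what "respecting the stride order" means for a user-defined op; the op name and implementation below are made up for illustration:

```python
# Hedged sketch (hypothetical op "mylib::scale"): with the flipped default, Inductor
# is expected to hand the custom op an input with the same stride order as in eager,
# e.g. a channels-last tensor stays channels-last at the op boundary.
import torch

@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

@torch.compile
def f(x):
    return scale(x, 2.0)

x = torch.randn(2, 3, 8, 8).to(memory_format=torch.channels_last)
y = f(x)
```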

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135239
Approved by: https://github.com/albanD
ghstack dependencies: #135391
2024-09-10 14:27:43 +00:00
386b313028 Handle KeyError for compiler collective in scalars too (#135385)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135385
Approved by: https://github.com/jansel
2024-09-10 12:33:04 +00:00
6d7cbc20d2 Add dynamo itertools.pairwise support (#135416)
Fixes #133766
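
A hedged sketch of the newly supported pattern (requires Python >= 3.10 for `itertools.pairwise`):

```python
# Hedged sketch: with this support, dynamo can trace itertools.pairwise without a
# graph break; fullgraph=True would raise if a break were still required.
import itertools
import torch

@torch.compile(fullgraph=True)
def sum_adjacent(xs):
    return [a + b for a, b in itertools.pairwise(xs)]

out = sum_adjacent([torch.ones(3), 2 * torch.ones(3), 3 * torch.ones(3)])
```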

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135416
Approved by: https://github.com/XuehaiPan, https://github.com/jansel

Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
2024-09-10 11:37:59 +00:00
ca16956b20 [Inductor] Generalize device guard codegen for cpp_wrapper mode. (#134761)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134761
Approved by: https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #134693
2024-09-10 10:11:52 +00:00
67735d1ee8 [Inductor] Generalize is_cuda to specific device_type to make cpp_wrapper mode be extensible (#134693)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134693
Approved by: https://github.com/ezyang, https://github.com/EikanWang, https://github.com/jansel
2024-09-10 10:11:13 +00:00
6e13f5eb38 [FlexAttention] Add broadcast support for kv batch dimension (#135505)
This PR adds broadcast support for the KV batch dimension.

## Details
Consider Q of shape `[Bq, Hq, Q_LEN, D]`, and K, V of shape `[Bkv, Hkv, KV_LEN, D]`. Prior to this PR, we required `Bq == Bkv`. However, for some use cases we may have `Bkv < Bq`. For example, in paged attention we provide K, V of shape `[1, Hkv, MAX_LEN, D]` while still providing Q of shape `[Bq, Hq, Q_LEN, D]`. Here, MAX_LEN is the maximal number of tokens supported by paged attention.

This PR relaxes the requirement to `Bq == Bkv or (Bq > 1 and Bkv == 1)`, i.e. the KV batch dimension can be broadcast. This support covers flex decoding as well as the flex attention forward and backward passes.
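
A hedged sketch of the relaxed shape requirement (illustrative sizes; a GPU is assumed):

```python
# Hedged sketch: Q has batch Bq while K/V share a single KV batch (Bkv == 1),
# matching the paged-attention-style layout described above.
import torch
from torch.nn.attention.flex_attention import flex_attention

Bq, Hq, Q_LEN, D = 4, 16, 128, 64
Hkv, KV_LEN = 16, 256

q = torch.randn(Bq, Hq, Q_LEN, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, Hkv, KV_LEN, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, Hkv, KV_LEN, D, device="cuda", dtype=torch.bfloat16)

out = torch.compile(flex_attention)(q, k, v)   # out: [Bq, Hq, Q_LEN, D]
```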

## Benchmark
GPU: H100

We see a negligible (1-2%) performance change from this PR when `Bq == Bkv`.

```
python benchmarks/transformer/score_mod.py --calculate-bwd
```
### Perf before this PR

**FWD**

| Type    |   Speedup | score_mod     | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)        |
|---------|-----------|---------------|------------|----------------|------------------------------|
| Average |     0.743 |               |            |                |                              |
| Max     |     0.955 | head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)   |
| Min     |     0.548 | relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) |

**BWD**

| Type    |   Speedup | score_mod   | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)       |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average |     0.834 |             |            |                |                             |
| Max     |     1.261 | head_bias   | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)   |
| Min     |     0.456 | None        | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |

<details>
<summary> Full performance sweep </summary>

| score_mod     | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)         |   fwd_eager_time |   fwd_compiled_time |   bwd_eager_time |   bwd_compiled_time |   fwd_speedup |   bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None          | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.264 |              17.184 |          107.040 |             140.800 |         0.888 |         0.760 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.840 |              19.744 |          112.576 |             140.064 |         0.802 |         0.804 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.232 |              17.344 |           87.744 |             142.496 |         0.878 |         0.616 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.264 |              17.184 |          108.192 |             143.328 |         0.888 |         0.755 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.904 |              22.400 |          106.432 |             136.512 |         0.889 |         0.780 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.424 |              26.752 |           91.712 |             106.688 |         0.726 |         0.860 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.808 |              22.432 |           89.024 |             101.920 |         0.883 |         0.873 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.840 |              22.272 |           88.896 |             102.592 |         0.891 |         0.867 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           30.240 |              32.416 |          116.768 |             112.256 |         0.933 |         1.040 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           29.536 |              37.024 |          113.664 |             102.688 |         0.798 |         1.107 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           30.656 |              32.800 |          116.992 |             127.008 |         0.935 |         0.921 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           30.592 |              32.480 |          116.928 |             112.160 |         0.942 |         1.043 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           40.448 |              61.920 |          198.656 |             204.512 |         0.653 |         0.971 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           37.760 |              62.528 |          189.536 |             170.624 |         0.604 |         1.111 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           40.896 |              62.368 |          198.304 |             205.824 |         0.656 |         0.963 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           40.448 |              61.952 |          198.432 |             203.648 |         0.653 |         0.974 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          318.528 |             355.904 |          947.232 |            1162.496 |         0.895 |         0.815 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          199.776 |             252.128 |          677.792 |             813.184 |         0.792 |         0.834 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          316.512 |             363.328 |          947.712 |            1361.984 |         0.871 |         0.696 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          317.984 |             356.864 |          947.264 |            1165.024 |         0.891 |         0.813 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          446.656 |             734.656 |         1664.288 |            2172.960 |         0.608 |         0.766 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          278.688 |             467.648 |         1182.624 |            1339.296 |         0.596 |         0.883 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          447.872 |             744.096 |         1662.944 |            2196.544 |         0.602 |         0.757 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          448.128 |             732.928 |         1663.072 |            2156.800 |         0.611 |         0.771 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.648 |              16.640 |          107.520 |             143.008 |         0.940 |         0.752 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.776 |              18.240 |          129.056 |             141.920 |         0.865 |         0.909 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.168 |              16.640 |          103.616 |             139.648 |         0.912 |         0.742 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.616 |              16.640 |          128.608 |             164.448 |         0.938 |         0.782 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.776 |              21.952 |          125.344 |             170.304 |         0.901 |         0.736 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.776 |              23.712 |          104.288 |             196.896 |         0.834 |         0.530 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.072 |              21.952 |          102.080 |             177.056 |         0.869 |         0.577 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.648 |              21.920 |          109.920 |             170.848 |         0.896 |         0.643 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           30.464 |              31.936 |          127.808 |             228.832 |         0.954 |         0.559 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           29.472 |              33.856 |          113.152 |             215.072 |         0.871 |         0.526 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           30.496 |              32.160 |          116.576 |             231.744 |         0.948 |         0.503 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           30.464 |              31.904 |          116.320 |             229.824 |         0.955 |         0.506 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           40.480 |              61.440 |          176.448 |             345.312 |         0.659 |         0.511 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           38.304 |              59.424 |          169.312 |             371.360 |         0.645 |         0.456 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           40.960 |              61.760 |          176.512 |             358.912 |         0.663 |         0.492 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           40.352 |              61.696 |          176.512 |             344.928 |         0.654 |         0.512 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          316.224 |             357.728 |          905.728 |            1668.448 |         0.884 |         0.543 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          199.904 |             248.416 |          636.544 |            1109.088 |         0.805 |         0.574 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          314.880 |             363.616 |          906.304 |            1658.176 |         0.866 |         0.547 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          316.160 |             354.368 |          906.080 |            1649.024 |         0.892 |         0.549 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          446.912 |             739.840 |         1555.808 |            2521.952 |         0.604 |         0.617 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          279.776 |             463.904 |         1068.928 |            1849.888 |         0.603 |         0.578 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          446.080 |             748.960 |         1553.504 |            2629.888 |         0.596 |         0.591 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          446.208 |             740.608 |         1558.880 |            2524.960 |         0.602 |         0.617 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           33.568 |              41.280 |          170.016 |             147.584 |         0.813 |         1.152 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           30.688 |              43.040 |          159.552 |             146.720 |         0.713 |         1.087 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           34.112 |              41.504 |          170.112 |             152.672 |         0.822 |         1.114 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           34.240 |              41.152 |          170.272 |             134.976 |         0.832 |         1.261 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.672 |              76.416 |          295.296 |             263.648 |         0.637 |         1.120 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           45.088 |              72.576 |          281.920 |             237.664 |         0.621 |         1.186 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.032 |              76.672 |          295.520 |             265.248 |         0.626 |         1.114 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.096 |              76.096 |          295.456 |             262.112 |         0.632 |         1.127 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           93.920 |             111.232 |          401.568 |             382.944 |         0.844 |         1.049 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           68.192 |              95.232 |          338.752 |             326.816 |         0.716 |         1.037 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           93.984 |             111.840 |          401.856 |             444.224 |         0.840 |         0.905 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           94.176 |             110.496 |          401.600 |             383.136 |         0.852 |         1.048 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          131.488 |             227.040 |          727.424 |             739.712 |         0.579 |         0.983 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |           95.616 |             169.760 |          616.864 |             574.112 |         0.563 |         1.074 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          131.680 |             228.672 |          727.616 |             746.048 |         0.576 |         0.975 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          131.104 |             225.696 |          727.904 |             735.392 |         0.581 |         0.990 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1227.296 |            1386.656 |         3720.192 |            4539.904 |         0.885 |         0.819 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |          691.360 |             831.712 |         2515.872 |            3067.808 |         0.831 |         0.820 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1228.192 |            1403.136 |         3715.520 |            5309.280 |         0.875 |         0.700 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1229.024 |            1384.992 |         3715.904 |            4550.368 |         0.887 |         0.817 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1784.832 |            2865.888 |         6539.840 |            8460.224 |         0.623 |         0.773 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1017.408 |            1660.480 |         4369.824 |            5056.992 |         0.613 |         0.864 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1792.448 |            2904.864 |         6546.080 |            8537.024 |         0.617 |         0.767 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1795.552 |            2856.864 |         6544.672 |            8400.160 |         0.629 |         0.779 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           34.240 |              38.880 |          148.832 |             179.936 |         0.881 |         0.827 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           31.168 |              38.080 |          138.528 |             167.552 |         0.818 |         0.827 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           34.240 |              39.168 |          148.512 |             181.248 |         0.874 |         0.819 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           34.240 |              38.784 |          148.864 |             180.224 |         0.883 |         0.826 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           48.832 |              76.352 |          253.632 |             295.968 |         0.640 |         0.857 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           45.760 |              65.792 |          239.040 |             290.752 |         0.696 |         0.822 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           48.768 |              76.576 |          253.312 |             304.032 |         0.637 |         0.833 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           48.768 |              76.192 |          253.600 |             296.096 |         0.640 |         0.856 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           93.728 |             109.728 |          357.696 |             498.912 |         0.854 |         0.717 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           68.704 |              92.288 |          295.616 |             386.240 |         0.744 |         0.765 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           93.632 |             111.392 |          357.408 |             512.448 |         0.841 |         0.697 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           93.280 |             109.952 |          357.696 |             501.440 |         0.848 |         0.713 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          131.392 |             230.496 |          612.224 |             807.552 |         0.570 |         0.758 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |           96.512 |             165.184 |          502.624 |             672.384 |         0.584 |         0.748 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          131.360 |             232.608 |          612.064 |             832.320 |         0.565 |         0.735 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          131.008 |             230.528 |          612.640 |             804.320 |         0.568 |         0.762 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1227.968 |            1377.408 |         3477.920 |            5324.384 |         0.892 |         0.653 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |          695.264 |             824.544 |         2268.224 |            3210.208 |         0.843 |         0.707 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1228.640 |            1404.576 |         3476.832 |            5463.456 |         0.875 |         0.636 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1228.416 |            1378.752 |         3478.048 |            5367.712 |         0.891 |         0.648 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1788.736 |            2867.712 |         6039.520 |            8616.256 |         0.624 |         0.701 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1021.952 |            1653.824 |         3866.208 |            5306.848 |         0.618 |         0.729 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1786.752 |            2896.352 |         6044.128 |            8871.360 |         0.617 |         0.681 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1786.080 |            2868.672 |         6040.160 |            8550.144 |         0.623 |         0.706 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           57.504 |              71.552 |          312.768 |             255.040 |         0.804 |         1.226 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           49.472 |              71.104 |          285.696 |             243.520 |         0.696 |         1.173 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           58.112 |              72.896 |          312.768 |             288.256 |         0.797 |         1.085 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           57.952 |              71.680 |          312.768 |             255.552 |         0.808 |         1.224 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           82.336 |             144.256 |          580.128 |             500.160 |         0.571 |         1.160 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           76.160 |             123.712 |          552.544 |             447.648 |         0.616 |         1.234 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           82.400 |             145.184 |          580.032 |             504.032 |         0.568 |         1.151 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           82.368 |             143.904 |          580.192 |             499.936 |         0.572 |         1.161 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          177.216 |             209.568 |          787.872 |             747.712 |         0.846 |         1.054 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          121.984 |             168.256 |          651.968 |             628.256 |         0.725 |         1.038 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          177.088 |             211.488 |          788.320 |             864.352 |         0.837 |         0.912 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          177.440 |             208.576 |          787.424 |             749.120 |         0.851 |         1.051 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          249.472 |             441.376 |         1405.440 |            1431.648 |         0.565 |         0.982 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          172.960 |             312.064 |         1172.064 |            1096.448 |         0.554 |         1.069 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          249.632 |             446.336 |         1405.408 |            1448.480 |         0.559 |         0.970 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          250.944 |             440.128 |         1406.624 |            1421.952 |         0.570 |         0.989 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2418.720 |            2747.936 |         7330.432 |            9023.712 |         0.880 |         0.812 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         1353.696 |            1608.480 |         4941.696 |            6078.752 |         0.842 |         0.813 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2427.456 |            2746.816 |         7329.792 |           10539.968 |         0.884 |         0.695 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2426.688 |            2763.168 |         7336.256 |            9057.536 |         0.878 |         0.810 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3554.240 |            5634.400 |        12919.872 |           16843.489 |         0.631 |         0.767 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         2003.648 |            3250.784 |         8610.144 |           10015.424 |         0.616 |         0.860 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3582.080 |            5710.944 |        12923.328 |           17011.871 |         0.627 |         0.760 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3581.920 |            5618.144 |        12934.528 |           16745.888 |         0.638 |         0.772 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           57.120 |              71.232 |          269.760 |             295.680 |         0.802 |         0.912 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           49.408 |              65.312 |          242.304 |             253.952 |         0.756 |         0.954 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           57.504 |              72.544 |          269.632 |             298.976 |         0.793 |         0.902 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           57.760 |              71.040 |          269.600 |             296.640 |         0.813 |         0.909 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           82.336 |             147.168 |          466.080 |             487.456 |         0.559 |         0.956 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           76.704 |             115.040 |          435.392 |             453.248 |         0.667 |         0.961 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           81.856 |             147.424 |          465.920 |             499.552 |         0.555 |         0.933 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           81.760 |             146.656 |          466.176 |             485.984 |         0.557 |         0.959 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          176.608 |             206.976 |          678.080 |             866.976 |         0.853 |         0.782 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          121.664 |             164.768 |          538.240 |             636.160 |         0.738 |         0.846 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          176.608 |             209.664 |          677.696 |             883.424 |         0.842 |         0.767 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          177.440 |             207.840 |          677.248 |             868.288 |         0.854 |         0.780 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          250.272 |             449.536 |         1163.424 |            1420.832 |         0.557 |         0.819 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          173.472 |             305.376 |          929.408 |            1104.544 |         0.568 |         0.841 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          249.376 |             454.976 |         1163.648 |            1455.296 |         0.548 |         0.800 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          250.368 |             450.144 |         1163.520 |            1409.984 |         0.556 |         0.825 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2416.576 |            2726.208 |         6835.520 |           10442.784 |         0.886 |         0.655 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         1357.440 |            1590.752 |         4433.664 |            5975.296 |         0.853 |         0.742 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2427.360 |            2747.040 |         6853.056 |           10670.784 |         0.884 |         0.642 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2441.120 |            2718.944 |         6836.640 |           10433.792 |         0.898 |         0.655 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3555.392 |            5620.960 |        11944.000 |           16504.801 |         0.633 |         0.724 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         2010.848 |            3241.152 |         7636.064 |            9870.464 |         0.620 |         0.774 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3557.440 |            5688.352 |        11935.744 |           17090.496 |         0.625 |         0.698 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3562.720 |            5630.432 |        11939.168 |           16392.033 |         0.633 |         0.728 |

</details>

### Perf after this PR

**FWD**

| Type    |   Speedup | score_mod     | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)      |
|---------|-----------|---------------|------------|----------------|----------------------------|
| Average |     0.776 |               |            |                |                            |
| Max     |     1.006 | None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) |
| Min     |     0.566 | relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128) |

**BWD**

| Type    |   Speedup | score_mod   | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)       |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average |     0.817 |             |            |                |                             |
| Max     |     1.150 | None        | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 128) |
| Min     |     0.454 | None        | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |

<details>
<summary> Full performance sweep </summary>

| score_mod     | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)         |   fwd_eager_time |   fwd_compiled_time |   bwd_eager_time |   bwd_compiled_time |   fwd_speedup |   bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None          | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.680 |              17.056 |           64.544 |              73.376 |         0.919 |         0.880 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.712 |              19.872 |           65.408 |              72.864 |         0.791 |         0.898 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           16.160 |              17.280 |           64.896 |              73.888 |         0.935 |         0.878 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           16.192 |              17.120 |           64.896 |              75.424 |         0.946 |         0.860 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.648 |              22.496 |           89.184 |              82.592 |         0.873 |         1.080 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           20.320 |              26.816 |           91.264 |              82.880 |         0.758 |         1.101 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           20.096 |              22.528 |           89.184 |              83.776 |         0.892 |         1.065 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.680 |              22.432 |           89.184 |             120.096 |         0.877 |         0.743 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           32.384 |              32.512 |          119.232 |             128.960 |         0.996 |         0.925 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           30.176 |              37.248 |          113.664 |             119.520 |         0.810 |         0.951 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           32.512 |              32.928 |          119.264 |             131.456 |         0.987 |         0.907 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           32.448 |              32.704 |          119.200 |             128.352 |         0.992 |         0.929 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           41.952 |              62.176 |          199.040 |             214.304 |         0.675 |         0.929 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           39.744 |              62.880 |          189.504 |             179.968 |         0.632 |         1.053 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           41.472 |              62.784 |          199.136 |             217.664 |         0.661 |         0.915 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           42.048 |              61.952 |          199.168 |             214.496 |         0.679 |         0.929 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          341.184 |             357.632 |          980.256 |            1328.896 |         0.954 |         0.738 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          212.576 |             252.960 |          673.888 |             824.864 |         0.840 |         0.817 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          340.000 |             363.296 |          980.768 |            1375.808 |         0.936 |         0.713 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          340.768 |             356.832 |          980.960 |            1326.272 |         0.955 |         0.740 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          459.392 |             737.120 |         1678.240 |            2205.248 |         0.623 |         0.761 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          292.672 |             468.096 |         1178.016 |            1371.584 |         0.625 |         0.859 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          462.144 |             745.312 |         1680.000 |            2252.512 |         0.620 |         0.746 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          462.112 |             736.576 |         1679.008 |            2216.480 |         0.627 |         0.758 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           16.064 |              16.704 |          105.120 |             120.768 |         0.962 |         0.870 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.552 |              18.144 |          107.136 |             121.696 |         0.857 |         0.880 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           16.096 |              16.768 |          102.688 |             120.864 |         0.960 |         0.850 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           16.032 |              16.576 |          104.736 |             124.672 |         0.967 |         0.840 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.392 |              21.952 |          104.736 |             174.656 |         0.883 |         0.600 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           20.128 |              23.712 |          105.216 |             199.008 |         0.849 |         0.529 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.904 |              21.888 |          103.744 |             179.520 |         0.909 |         0.578 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.968 |              21.952 |          104.640 |             177.312 |         0.910 |         0.590 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           32.096 |              31.904 |          118.720 |             231.968 |         1.006 |         0.512 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           30.528 |              33.952 |          112.480 |             218.304 |         0.899 |         0.515 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           32.160 |              32.224 |          118.752 |             237.312 |         0.998 |         0.500 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           32.128 |              32.032 |          118.240 |             233.120 |         1.003 |         0.507 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           41.312 |              61.280 |          177.408 |             350.688 |         0.674 |         0.506 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           39.552 |              59.360 |          168.832 |             371.488 |         0.666 |         0.454 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           41.984 |              61.696 |          177.376 |             360.416 |         0.680 |         0.492 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           41.312 |              61.760 |          177.184 |             355.744 |         0.669 |         0.498 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          339.744 |             357.888 |          939.712 |            1665.376 |         0.949 |         0.564 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          212.608 |             248.832 |          633.280 |            1122.848 |         0.854 |         0.564 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          339.712 |             363.232 |          940.448 |            1689.440 |         0.935 |         0.557 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          341.056 |             355.264 |          940.128 |            1641.152 |         0.960 |         0.573 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          460.736 |             741.024 |         1569.824 |            2559.552 |         0.622 |         0.613 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          293.856 |             464.192 |         1066.240 |            1840.416 |         0.633 |         0.579 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          460.704 |             753.152 |         1570.112 |            2641.088 |         0.612 |         0.594 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          460.832 |             745.536 |         1570.144 |            2602.560 |         0.618 |         0.603 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           35.680 |              41.280 |          171.840 |             158.176 |         0.864 |         1.086 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           31.360 |              42.976 |          158.912 |             139.264 |         0.730 |         1.141 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           35.168 |              41.600 |          171.648 |             161.344 |         0.845 |         1.064 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           35.136 |              41.152 |          171.808 |             158.336 |         0.854 |         1.085 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.832 |              76.384 |          295.680 |             277.696 |         0.639 |         1.065 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           45.632 |              72.512 |          281.760 |             250.752 |         0.629 |         1.124 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           49.504 |              76.608 |          295.584 |             279.712 |         0.646 |         1.057 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.864 |              75.904 |          295.456 |             277.568 |         0.644 |         1.064 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           99.392 |             111.232 |          408.640 |             442.656 |         0.894 |         0.923 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           71.392 |              95.168 |          338.784 |             341.760 |         0.750 |         0.991 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           99.808 |             112.256 |          408.608 |             456.160 |         0.889 |         0.896 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |          100.032 |             110.816 |          408.512 |             444.192 |         0.903 |         0.920 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          135.040 |             226.112 |          726.880 |             774.176 |         0.597 |         0.939 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |           99.904 |             169.696 |          616.448 |             607.104 |         0.589 |         1.015 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          135.488 |             228.384 |          727.776 |             782.368 |         0.593 |         0.930 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          135.744 |             225.664 |          728.000 |             773.600 |         0.602 |         0.941 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1324.192 |            1387.808 |         3866.944 |            5217.184 |         0.954 |         0.741 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |          738.464 |             832.608 |         2507.392 |            3146.688 |         0.887 |         0.797 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1326.016 |            1404.256 |         3867.872 |            5382.624 |         0.944 |         0.719 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1326.144 |            1386.688 |         3867.552 |            5203.264 |         0.956 |         0.743 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1847.488 |            2866.336 |         6612.704 |            8597.696 |         0.645 |         0.769 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1066.592 |            1660.640 |         4357.696 |            5174.016 |         0.642 |         0.842 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1850.464 |            2905.408 |         6616.928 |            8793.280 |         0.637 |         0.752 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1848.896 |            2834.720 |         6623.872 |            8637.920 |         0.652 |         0.767 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           36.384 |              38.656 |          150.336 |             182.624 |         0.941 |         0.823 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           31.360 |              38.112 |          137.664 |             171.840 |         0.823 |         0.801 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           36.608 |              39.040 |          150.528 |             183.872 |         0.938 |         0.819 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           36.064 |              38.656 |          150.560 |             183.520 |         0.933 |         0.820 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           49.344 |              76.352 |          253.920 |             301.440 |         0.646 |         0.842 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           46.720 |              65.824 |          239.424 |             296.384 |         0.710 |         0.808 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           49.248 |              76.416 |          253.728 |             307.808 |         0.644 |         0.824 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           49.376 |              76.288 |          253.728 |             304.736 |         0.647 |         0.833 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           99.264 |             110.144 |          364.960 |             503.072 |         0.901 |         0.725 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           71.136 |              92.384 |          294.432 |             393.056 |         0.770 |         0.749 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           99.200 |             111.360 |          365.152 |             512.640 |         0.891 |         0.712 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           99.264 |             110.240 |          365.088 |             504.224 |         0.900 |         0.724 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          135.680 |             230.336 |          613.472 |             816.896 |         0.589 |         0.751 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          100.256 |             165.088 |          502.144 |             676.480 |         0.607 |         0.742 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          135.008 |             232.480 |          613.184 |             836.672 |         0.581 |         0.733 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          135.232 |             230.624 |          613.536 |             827.136 |         0.586 |         0.742 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1324.064 |            1378.688 |         3631.808 |            5308.384 |         0.960 |         0.684 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |          731.776 |             826.688 |         2263.168 |            3241.344 |         0.885 |         0.698 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1316.128 |            1403.200 |         3625.088 |            5550.688 |         0.938 |         0.653 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1311.904 |            1378.880 |         3616.320 |            5353.696 |         0.951 |         0.675 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1837.856 |            2887.392 |         6121.632 |            8586.656 |         0.637 |         0.713 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1066.976 |            1654.368 |         3843.136 |            5291.040 |         0.645 |         0.726 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1854.208 |            2896.832 |         6130.112 |            8745.984 |         0.640 |         0.701 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1860.512 |            2889.344 |         6135.648 |            8750.592 |         0.644 |         0.701 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           60.640 |              71.552 |          315.968 |             296.512 |         0.847 |         1.066 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           50.784 |              71.040 |          284.288 |             258.880 |         0.715 |         1.098 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           61.312 |              72.704 |          315.680 |             302.016 |         0.843 |         1.045 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           60.800 |              71.776 |          316.320 |             297.152 |         0.847 |         1.065 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           84.576 |             144.416 |          580.576 |             535.936 |         0.586 |         1.083 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           76.064 |             123.648 |          553.344 |             481.376 |         0.615 |         1.150 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           84.160 |             145.248 |          581.024 |             540.000 |         0.579 |         1.076 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           84.512 |             143.552 |          581.088 |             535.776 |         0.589 |         1.085 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          189.152 |             209.408 |          798.400 |             868.704 |         0.903 |         0.919 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          127.552 |             168.800 |          650.816 |             663.328 |         0.756 |         0.981 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          189.376 |             211.360 |          798.080 |             895.552 |         0.896 |         0.891 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          189.440 |             208.576 |          797.888 |             873.152 |         0.908 |         0.914 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          257.536 |             441.760 |         1408.960 |            1514.720 |         0.583 |         0.930 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          179.328 |             312.096 |         1170.368 |            1177.472 |         0.575 |         0.994 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          259.264 |             446.944 |         1408.768 |            1530.400 |         0.580 |         0.921 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          258.080 |             440.480 |         1408.864 |            1514.144 |         0.586 |         0.930 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2595.808 |            2771.456 |         7616.704 |           10405.248 |         0.937 |         0.732 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         1435.744 |            1610.336 |         4927.520 |            6220.000 |         0.892 |         0.792 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2595.264 |            2745.056 |         7611.232 |           10631.392 |         0.945 |         0.716 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2576.256 |            2735.456 |         7626.400 |           10346.976 |         0.942 |         0.737 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3679.744 |            5634.816 |        13077.056 |           17182.528 |         0.653 |         0.761 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         2099.360 |            3250.176 |         8589.664 |           10236.672 |         0.646 |         0.839 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3676.800 |            5716.288 |        13073.088 |           17311.071 |         0.643 |         0.755 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3679.136 |            5570.496 |        13070.720 |           17192.863 |         0.660 |         0.760 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           61.600 |              71.008 |          272.320 |             300.000 |         0.868 |         0.908 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           50.176 |              65.344 |          241.568 |             258.912 |         0.768 |         0.933 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           61.120 |              72.512 |          272.672 |             305.408 |         0.843 |         0.893 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           61.248 |              71.136 |          272.640 |             301.120 |         0.861 |         0.905 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           83.872 |             146.784 |          466.912 |             496.832 |         0.571 |         0.940 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           76.704 |             115.072 |          435.584 |             462.112 |         0.667 |         0.943 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           83.392 |             147.392 |          466.656 |             504.448 |         0.566 |         0.925 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           83.360 |             146.688 |          466.656 |             499.040 |         0.568 |         0.935 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          189.024 |             207.584 |          684.768 |             873.568 |         0.911 |         0.784 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          126.944 |             164.288 |          536.192 |             645.984 |         0.773 |         0.830 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          188.768 |             209.760 |          684.096 |             897.504 |         0.900 |         0.762 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          189.408 |             207.776 |          685.024 |             876.384 |         0.912 |         0.782 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          259.168 |             449.536 |         1167.936 |            1433.280 |         0.577 |         0.815 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          180.000 |             305.312 |          928.000 |            1113.920 |         0.590 |         0.833 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          258.464 |             455.136 |         1167.808 |            1462.848 |         0.568 |         0.798 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          257.824 |             450.208 |         1167.744 |            1448.000 |         0.573 |         0.806 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2598.368 |            2729.120 |         7134.400 |           10381.632 |         0.952 |         0.687 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         1435.456 |            1591.040 |         4424.768 |            6035.808 |         0.902 |         0.733 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2594.752 |            2725.952 |         7128.384 |           10822.496 |         0.952 |         0.659 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2597.888 |            2716.960 |         7101.568 |           10385.440 |         0.956 |         0.684 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3647.648 |            5581.632 |        12089.952 |           16667.233 |         0.654 |         0.725 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         2093.952 |            3241.440 |         7579.392 |            9847.936 |         0.646 |         0.770 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3650.528 |            5650.688 |        12105.568 |           16963.680 |         0.646 |         0.714 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3680.064 |            5585.312 |        12117.504 |           16935.040 |         0.659 |         0.716 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135505
Approved by: https://github.com/Chillee
2024-09-10 09:30:02 +00:00
23b1486185 [MPS] Allow nan mean reduction in nll_loss (#135434)
This PR allows results from `nll_loss` to be `nan`, which is the same behavior as with CUDA and CPU https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162.

Fixes #134431

Ref #64572 #119108
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135434
Approved by: https://github.com/malfet
2024-09-10 08:37:59 +00:00
9902b349cb [Inductor] Make static_input_idxs a set for faster lookup (#135314)
`static_input_idxs` is only used for lookups. With large models, this is a large list. This takes over a millisecond in some cases.
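A minimal sketch of the data-structure change being described (the index values and helper below are made up for illustration):

```python
# Membership checks against a Python list are O(n) per lookup; a set gives
# O(1) average-case lookups, which matters when the check runs per input.
static_input_idxs = list(range(0, 100_000, 2))   # hypothetical large list of indices
static_input_idxs_set = set(static_input_idxs)   # one-time conversion

def is_static(idx: int) -> bool:
    return idx in static_input_idxs_set          # constant-time lookup instead of a scan
```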

Profile before change:
<img width="824" alt="image" src="https://github.com/user-attachments/assets/002a0775-fd2f-4d27-8cf2-812b502d7d5e">

Profile after change: gaps are smaller, 1ms speedup before launching the cuda graph
<img width="794" alt="image" src="https://github.com/user-attachments/assets/12a0a0b9-2cc1-4d53-ac87-9bd5140a46f5">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135314
Approved by: https://github.com/oulgen
2024-09-10 07:27:55 +00:00
5a9ac83e94 Fix doc (#135551)
Differential Revision: [D62412667](https://our.internmc.facebook.com/intern/diff/D62412667/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135551
Approved by: https://github.com/yushangdi
ghstack dependencies: #135549
2024-09-10 07:18:44 +00:00
1adf28a5c0 [inductor] print triton float64 constants correctly (#135260)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135260
Approved by: https://github.com/jansel
2024-09-10 07:05:02 +00:00
c18052da0e Add some minor doc improvement and ban using training IR for unflattener (#135549)
Title

Differential Revision: [D62412490](https://our.internmc.facebook.com/intern/diff/D62412490/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135549
Approved by: https://github.com/yushangdi
2024-09-10 06:48:42 +00:00
c0d2f991b1 Increase TRITON_MAX_BLOCK['X'] (#135181)
Fixes #135028

As title, increase `TRITON_MAX_BLOCK['X']` to 4096 and fix an error, thanks to @Chillee: https://github.com/pytorch/pytorch/pull/133300/files#r1744706189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135181
Approved by: https://github.com/jansel
2024-09-10 05:54:37 +00:00
e889252493 Implementation of scan (#134102)
This operation is supposed to be the counterpart to `associative_scan`, but it can also operate with non-associative functions.
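For reference, a pure-Python sketch of what such a scan computes; this is only a conceptual reference implementation, not the actual higher-order-op signature:

```python
import torch

def reference_scan(combine_fn, init, xs):
    # Carry an accumulator through the sequence and collect the per-step outputs.
    carry, ys = init, []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)

# A non-associative combine function (a running subtraction) still works here,
# which is what distinguishes scan from associative_scan.
xs = torch.arange(5, dtype=torch.float32)
final_carry, outputs = reference_scan(lambda c, x: (c - x, c - x), torch.tensor(0.0), xs)
print(final_carry, outputs)  # tensor(-10.) tensor([  0.,  -1.,  -3.,  -6., -10.])
```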

@ydwu4

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134102
Approved by: https://github.com/ydwu4
2024-09-10 04:51:16 +00:00
6546c6186d do not raise when flatten_fn_with_keys not found when suggesting fixes (#135518)
Test Plan: added test

Differential Revision: D62395371

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135518
Approved by: https://github.com/zhxchen17
2024-09-10 03:47:36 +00:00
1d9fefff19 [DCP] Fixes the stateless optimizer issue of distributed state_dict (#135535)
Some optimizers don't have states, which can cause get_state_dict/set_state_dict to behave incorrectly. This PR fixes the issue.

fixes: https://github.com/pytorch/pytorch/issues/133415

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135535
Approved by: https://github.com/wz337
2024-09-10 03:10:00 +00:00
7ec17b49cf Fix dynamo benchmark skip logic for cpu device (#135193)
Fixes #132380. Adjust the torchbench and huggingface skip-model lists so that `--no-skip` can be removed when running benchmarks on the 3 suites.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135193
Approved by: https://github.com/chuanqi129, https://github.com/jansel
2024-09-10 03:02:19 +00:00
146921007a [inductor] [cpp] fix the input contiguous check in max-autotune (#134982)
## Description
Fixes the FP32 accuracy failure of `resmlp_12_224` and BF16 accuracy failure of `volo_d1_224` in timm.

In this PR, we check whether the input is contiguous in the following way:
If it has a `FixedLayout`, we know the exact strides. For a `FlexibleLayout`, if its data is a `ComputedBuffer`, we can get the fill order of the buffer to decide whether it's contiguous. For all other cases, we don't use the GEMM template since we can't infer whether the input is contiguous.
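A hypothetical sketch of the stride condition in question (the helper name is made up):

```python
# The CPP GEMM template only supports inputs whose innermost stride is 1,
# i.e. tensors that are contiguous in the last dimension.
def gemm_template_supported(stride: tuple) -> bool:
    return len(stride) > 0 and stride[-1] == 1

print(gemm_template_supported((196, 1)))   # True: row-major layout of a (3072, 196) tensor
print(gemm_template_supported((1, 3072)))  # False: the column-major FixedLayout described below
```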

## Additional context
The current GEMM template only supports this case: `input.get_stride()[-1] == 1`. In `resmlp_12_224`, when we run into this check, the layout of `input` is a `FlexibleLayout`. The reason is that when realizing the input which is a `View` IR, the `convert_to_reinterpret_view` call fails:
d14fe3ffed/torch/_inductor/ir.py (L4712-L4715)

And it finally runs into this `copy_input` and returns a `FlexibleLayout`.
d14fe3ffed/torch/_inductor/ir.py (L4722)

When checking its stride, this `FlexibleLayout` indeed satisfies `input.get_stride()[-1] == 1`, but it is later decided to be a `FixedLayout` with `size = (3072, 196), stride = (1, 3072)`, which is not supported by the GEMM template, thus causing the accuracy issue in this model.
The `FlexibleLayout` is converted to `FixedLayout` during [CppPackedGemmTemplate.add_choices](d14fe3ffed/torch/_inductor/mkldnn_lowerings.py (L1051)) which calls [slice_nd](d14fe3ffed/torch/_inductor/codegen/cpp_template_kernel.py (L150)) when rendering the kernel (`slice_nd(X)`). When creating the `SliceView` IR, [as_storage_and_layout](d14fe3ffed/torch/_inductor/ir.py (L2288)) invokes
[decide_layout](d14fe3ffed/torch/_inductor/ir.py (L2135)) and converts it to a `FixedLayout` with `size = (3072, 196), stride = (1, 3072)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134982
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-09-10 02:47:38 +00:00
a71e5509bc [inductor]Add profiler to operatorbench (#135515)
Add profiling to operatorbench. The new argument `--profile` is added and the profiling trace is like the following figure.
<img width="954" alt="image" src="https://github.com/user-attachments/assets/5b00d6e3-4905-4a77-a5e9-9f62620a5fd5">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135515
Approved by: https://github.com/shunting314
2024-09-10 02:33:30 +00:00
136e28f616 Enable forward AD in functional.affine_grid (#135494)
Fixes #121411
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135494
Approved by: https://github.com/zou3519, https://github.com/soulitzer
2024-09-10 00:07:07 +00:00
39a61795e3 remove amax_ptr from scaled_gemm (#135421)
amax was removed from _scaled_mm by #128683. Remove it from the internal at::cuda::blas::scaled_gemm, as well.  This allows hipBLASLt to find additional solutions rather than forcing amax to be used and then discarding the result.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135421
Approved by: https://github.com/drisspg, https://github.com/eqy
2024-09-09 23:04:36 +00:00
b4feec9782 [xplat][XNNPACK] don't prefer static linkage in xplat for main target (#135529)
Building XNNPACK as a static library has some issues because of multiple global params floating around.

Let's try to get rid of it in xplat and see how it fares.

Differential Revision: [D60776152](https://our.internmc.facebook.com/intern/diff/D60776152/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D60776152/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135529
Approved by: https://github.com/kimishpatel, https://github.com/mcr229, https://github.com/kirklandsign
2024-09-09 22:47:01 +00:00
d81731615f [Dynamo] Adding CallFunctionNoArgsSource and (#135425)
CallFunctionNoArgsGuardAccessor to support torch.cuda.current_device()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135425
Approved by: https://github.com/anijain2305
2024-09-09 22:46:00 +00:00
e2f9a83b85 [ONNX] Drop final None values as inputs for nodes in exporter graph (#135520)
When the value for an optional input is not provided, it defaults to `None`, which gets translated to "" in the ONNX graph. To avoid this, if we have a list of inputs and the final few are all `None`, we strip them from the graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135520
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
2024-09-09 22:28:41 +00:00
70a65a8bd5 Revert "NJT <-> padded dense conversions (#125947)"
This reverts commit 09a5e88bef04d5485b70d8f65f46a675aaa52942.

Reverted https://github.com/pytorch/pytorch/pull/125947 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing dynamo test 09a5e88bef, maybe a landrace ([comment](https://github.com/pytorch/pytorch/pull/125947#issuecomment-2339228570))
2024-09-09 22:01:09 +00:00
689d278543 Revert "Add __init__.py to shape inference folder. (#135461)"
This reverts commit dced0d6d9f05f0962f74a3c6227f774111c15715.

Reverted https://github.com/pytorch/pytorch/pull/135461 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it exposes some public function without appropriate doc. I will reopen the issue with hi-prio so that it can be fixed properly ([comment](https://github.com/pytorch/pytorch/pull/135461#issuecomment-2339218382))
2024-09-09 21:55:13 +00:00
9b764491e3 Use upload-artifact@v4.4.0 for create_release.yml (#135528)
Fixes failure: https://github.com/pytorch/pytorch/actions/runs/10780281005/job/29895846007

Due to a broken sync between
```
actions/upload-artifact@v2
and
actions/download-artifact@v4.1.7
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135528
Approved by: https://github.com/kit1980, https://github.com/malfet
2024-09-09 20:48:52 +00:00
cbc6b30a24 Fix broken E2E tests on Linux machines (#135394)
Summary:
I'm not entirely sure why this is failing with an `ImportError` (according to lastnameye, a superclass of `ModuleNotFoundError`), but in our E2E tests on Linux machines (but not Macs?), we're seeing the import failure not getting caught --
`ImportError: cannot import name 'parutil' from 'libfb.py' (/data/sandcastle/boxes/eden-trunk-hg-full-fbsource/buck-out/v2/gen/fbsource/d0c916ec8d40ce11/arvr/libraries/ctrl/studies/replay/__ctrl-r__/ctrl-r#link-tree/libfb/py/__init__.py)` from this test run https://www.internalfb.com/sandcastle/workflow/2522015791331601269, an instance of this job:  https://www.internalfb.com/intern/test/844425085172858?ref_report_id=0 is the overall job

Test Plan:
`arc skycastle schedule tools/skycastle/workflows2/ctrl/js_tests.sky:test_js_e2e_replay_tests --sandcastle-spec-overrides '{"type": "fbcode", "unicastle_size": "I1_MEDIUM"}'`
->
https://www.internalfb.com/sandcastle/workflow/256705178764255769

Differential Revision: D62321167

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135394
Approved by: https://github.com/laithsakka
2024-09-09 20:18:08 +00:00
5b368de7f7 Revert "[ONNX] Update fake mode usage in onnx docs (#135512)"
This reverts commit a13c118994b4f118388d97a35abcb91a396cd437.

Reverted https://github.com/pytorch/pytorch/pull/135512 on behalf of https://github.com/davidberard98 due to failing test  https://github.com/pytorch/pytorch/actions/runs/10778813316/job/29891679127 ([comment](https://github.com/pytorch/pytorch/pull/135512#issuecomment-2338999090))
2024-09-09 20:15:12 +00:00
09a5e88bef NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values); usage is sketched after this list
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR
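A hedged usage sketch of the forward conversion (assuming the jagged-layout nested tensor factory; the shapes are made up). The reverse direction is omitted since, per the note above, it has no public API yet:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)
padded = nt.to_padded_tensor(0.0)  # dense (2, 3, 4) tensor; the shorter row is zero-padded
print(padded.shape)
```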

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-09 19:37:32 +00:00
a4e6a0b240 [split build] move periodic split builds into own concurrency group (#135510)
To avoid nightly workflows cancelling each other
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135510
Approved by: https://github.com/clee2000, https://github.com/huydhn, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-09-09 19:35:57 +00:00
4ab232d0c4 Fix symbolic number's type and tensor's dtype mismatch bug in Tensor ctor (#135433)
Fixes #135432

In the current implementation, if we try to store a symbolic number in Tensor's constructor, it assumes that the tensor's dtype and the symbolic number's type are matched, which is not the case.

In other words, if we try to store a `SymInt`, current implementation assumes tensor's dtype is `torch.int32`, `torch.int64` or something. And if we try to store a `SymFloat`, it assumes tensor's dtype is `torch.float32` or `torch.float64`. However, the tensor's dtype could also be `torch.float32` or something else when we try to store `SymInt`, which would be wrong.

This PR stores symbolic numbers according to the tensor's scalar type by wrapping the guarded number of a `SymInt` or `SymFloat` into a PyObject.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135433
Approved by: https://github.com/ezyang
2024-09-09 19:32:18 +00:00
2032f107d7 Don't try to tag s390x docker images (#135509)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135509
Approved by: https://github.com/atalman
2024-09-09 19:07:48 +00:00
5f7d956362 Fix bugs blocking flipping the default layout constraint for custom ops (#135391)
Fixes two things:
- For regular PyTorch ops, the default layout constraint tag is always
flexible_layout. This was a bug with #135238
- Mark the new quantized _wrapped_linear_prepack ops as flexible_layout.
  The metas for these are incorrect, I didn't want to fix them (and
  changing the default requires the metas actually be correct).

Test Plan:
- The next PR up in the stack. The PRs are split because the next one is
  riskier.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135391
Approved by: https://github.com/albanD
2024-09-09 18:24:21 +00:00
a13c118994 [ONNX] Update fake mode usage in onnx docs (#135512)
Update fake mode usage in onnx docs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135512
Approved by: https://github.com/justinchuby
2024-09-09 18:10:37 +00:00
21241bfeee [CP] Extend CP to support load-balancing shards (#132442)
This PR extends the current ring attention to support load-balancing shards -- the context/sequence is divided into `2 * world_size` shards and each rank gets shards `rank` and `(world_size * 2 - rank - 1)`. The data re-shuffling is done in the `context_parallel` API.
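An illustrative sketch of the shard assignment (not the library code):

```python
# With 2 * world_size shards, pairing shard `rank` with shard `2 * world_size - rank - 1`
# balances the causal-attention workload across ranks.
def shard_ids(rank: int, world_size: int) -> tuple:
    return (rank, 2 * world_size - rank - 1)

world_size = 4
for rank in range(world_size):
    print(rank, shard_ids(rank, world_size))
# 0 (0, 7), 1 (1, 6), 2 (2, 5), 3 (3, 4)
```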

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132442
Approved by: https://github.com/wconstab
2024-09-09 18:04:38 +00:00
73a6fc6e30 Revert "[Inductor] Make static_input_idxs a set for faster lookup (#135314)"
This reverts commit 011cae9570fb3c44b7f6f0c8004c470579ed21da.

Reverted https://github.com/pytorch/pytorch/pull/135314 on behalf of https://github.com/ZainRizvi due to Lint is failing on this file in trunk. See [GH job link](https://github.com/pytorch/pytorch/actions/runs/10777258770/job/29885960050) [HUD commit link](011cae9570) ([comment](https://github.com/pytorch/pytorch/pull/135314#issuecomment-2338678219))
2024-09-09 17:33:01 +00:00
09287e3af4 [MPS] Add regression test for fft.fftfreq (#135440)
The issue reported in #135223 was already solved in #128393. This PR adds a regression test for it.

Fixes #135223

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135440
Approved by: https://github.com/ezyang
2024-09-09 17:12:36 +00:00
16c3b8f87c [AOTI] Fix assert_function call in cpu autotune template (#135086)
Summary: In the ABI-compatible mode, assert_function should be AOTI_TORCH_CHECK.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135086
Approved by: https://github.com/chenyang78, https://github.com/angelayi
ghstack dependencies: #134857
2024-09-09 16:54:12 +00:00
9c6dff4941 [AOTI] Add C shim for aten.mkldnn_rnn_layer in cpp wrapper (#134857)
Summary: Support aten.mkldnn_rnn_layer in the ABI-compatible mode. Because aten.mkldnn_rnn_layer is an aten op, it is easier to add a C shim function for it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134857
Approved by: https://github.com/angelayi
2024-09-09 16:54:12 +00:00
0eb425a563 [Release] Apply Release changes scripts after release 2.4 (#135495)
Based on additional changes required for https://github.com/pytorch/pytorch/pull/128347
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135495
Approved by: https://github.com/kit1980
2024-09-09 16:49:04 +00:00
011cae9570 [Inductor] Make static_input_idxs a set for faster lookup (#135314)
`static_input_idxs` is only used for lookups. With large models, this is a large list. This takes over a millisecond in some cases.

Profile before change:
<img width="824" alt="image" src="https://github.com/user-attachments/assets/002a0775-fd2f-4d27-8cf2-812b502d7d5e">

Profile after change: gaps are smaller, 1ms speedup before launching the cuda graph
<img width="794" alt="image" src="https://github.com/user-attachments/assets/12a0a0b9-2cc1-4d53-ac87-9bd5140a46f5">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135314
Approved by: https://github.com/oulgen
2024-09-09 16:24:58 +00:00
dfb2b661f7 Use float data type for Half var_sum in batchnorm stats updating on CPU (#126525)
Using float data type for Half `var_sum` in batchnorm stats updating on CPU to avoid `var_sum` overflow since the representation range of Half is small.
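An illustration of the overflow being avoided (this is not the kernel code; the shapes and values are made up):

```python
import torch

# Half's largest finite value is ~65504, so a large sum of squares overflows
# when the result is kept in Half, but fits easily in float.
x = torch.full((100_000,), 4.0, dtype=torch.half)
print(x.pow(2).sum())           # inf: the Half result overflows
print(x.float().pow(2).sum())   # tensor(1600000.): accumulate var_sum in float instead
```
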
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126525
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-09-09 15:31:38 +00:00
5a69e0ebbe [MPS] Update decorator comments with issue ref (#135448)
Updating the comments with references to better places for context now that the bugs have been identified.

xref #135442 #135447 #134184

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135448
Approved by: https://github.com/ezyang
2024-09-09 15:18:52 +00:00
5e145861f2 [ONNX] Improves documentation of ONNX exporter (#135372)
The PR updates the documentation to reflect the changes introduced in pytorch 2.5 and related to onnx exporter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135372
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
2024-09-09 15:09:01 +00:00
c35b953531 Fix wrong error msg (#135423)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135423
Approved by: https://github.com/ezyang
2024-09-09 13:28:31 +00:00
dced0d6d9f Add __init__.py to shape inference folder. (#135461)
Fixes #135196

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135461
Approved by: https://github.com/ezyang
2024-09-09 13:27:58 +00:00
c0436c5701 [inductor][cpp][gemm] fix perf regression xcit_large_24_p8_224 (#134686) (#135438)
Fix #134686.

PR https://github.com/pytorch/pytorch/pull/132729 makes GEMM template faster for one of the GEMMs in xcit_large_24_p8_224:
```
SingleProcess AUTOTUNE benchmarking takes 1.7088 seconds and 1.9207 seconds precompiling
AUTOTUNE linear_unary(12544x3072, 768x3072, 768)
  cpp_packed_gemm_2 2.9371 ms 100.0%
  _linear_pointwise 3.1584 ms 93.0%
```

But it is slower than ATen in the e2e run due to different cache behavior. Access to the input data (12544x3072) is LLC-latency bound, and bottlenecks are seen due to memory synchronization (data transfers and coherence updates across processors). This PR tries to mitigate the problem by cooperatively loading different chunks of input data from different processors that share the input data.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135438
Approved by: https://github.com/leslie-fang-intel
2024-09-09 05:16:02 +00:00
cyy
60e8dc4374 Check function declarations in Caffe2 code (#134925)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134925
Approved by: https://github.com/ezyang
2024-09-09 05:03:29 +00:00
e6c3f58584 Fix example: Address broadcasting error in the addition of `attn_bias… (#135427)
…` and `attn_mask`, and correct device assignment for newly created variables in the method.

Fix example: Address broadcasting error in the addition of `attn_bias` and `attn_mask`, and correct device assignment for newly created variables in the method.

1. Adding `attn_bias += attn_mask` results in a broadcasting error. The expected shape of `attn_bias` is (L, S), so the output should also have the shape (L, S). However, when `attn_mask` has shape (N, num_heads, L, S), broadcasting occurs, leading to an output shape of (N, num_heads, L, S), which is not desired (see the sketch after this list).
2. `attn_bias` is a newly created variable within the method, but it is not assigned to the correct device.
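A small sketch of point 1 above (the shapes are hypothetical):

```python
import torch

N, num_heads, L, S = 2, 4, 8, 8
attn_bias = torch.zeros(L, S)
attn_mask = torch.zeros(N, num_heads, L, S)

out = attn_bias + attn_mask  # silently broadcasts
print(out.shape)             # torch.Size([2, 4, 8, 8]) -- not the intended (L, S)
```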

**This is my retry of PR #130209 . The PR has been merged into commit `d4a79d4a7c746068d25fe5cf9333495561f4ce1f`, but the modifications were overwritten by subsequent commits.**

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
@mikaylagawarecki  provided a more elegant implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135427
Approved by: https://github.com/ezyang
2024-09-09 03:47:34 +00:00
90e12cf63d Fix return type of nansum example. (#135435)
One of the examples in the documentation of `torch.nansum` contains a wrong return type. This fixes it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135435
Approved by: https://github.com/ezyang
2024-09-09 03:34:52 +00:00
44c08f4984 [Partitioner] Query whether nodes exist in graph faster (#135316)
Finding whether a node exists in graph.nodes (a linked list) takes too long. Use graph._find_nodes_lookup_table (a hash table) instead to speed up the lookup.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135316
Approved by: https://github.com/ezyang
2024-09-09 03:34:02 +00:00
b6186353c6 enable lazy_init for hpu (#135203)
Enables lazy_init for the hpu device.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135203
Approved by: https://github.com/ezyang
2024-09-09 03:32:20 +00:00
b7eb7256fb docs: torch.nn.utils.rnn.pack_padded_sequence: docs improve (#135417)
docs: `torch.nn.utils.rnn.pack_padded_sequence`: docs improve

/cc @mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135417
Approved by: https://github.com/ezyang
2024-09-09 03:16:11 +00:00
c1ae78be92 [inductor] calibration inductor windows uts (18/N) (#135449)
Skip the test_quantized_* UTs of `test/inductor/test_cpu_select_algorithm.py`.
The Windows inductor doesn't support quantization so far.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135449
Approved by: https://github.com/ezyang
2024-09-09 03:10:54 +00:00
defb515306 [NJT]Add permute ops support (#135336)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135336
Approved by: https://github.com/davidberard98
2024-09-08 21:00:41 +00:00
31c4e0d37d [inductor] Cleanup analysis done at lowering time (#135412)
Before this we would take multiple passes over the body of each IRNode as we did lowering.  This combines most analysis into `OpCounterCSE` so it can be done in a single pass.

Before:
![image](https://github.com/user-attachments/assets/0047db09-4258-4491-a9a6-b078e183092a)

After:
![image](https://github.com/user-attachments/assets/1e03adcb-8303-4bb1-8bbb-cc42dacd44d7)

This stack:
![image](https://github.com/user-attachments/assets/d6b50b24-c30c-4d23-8b1a-344b3ba65d7a)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135412
Approved by: https://github.com/oulgen
ghstack dependencies: #135286, #135306, #135377, #135400
2024-09-08 18:02:36 +00:00
53290ca00b [inductor] Refactor BaseSchedulerNode.__init__ (#135400)
Might be a small compile time improvement since we remove a call to extract_read_writes().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135400
Approved by: https://github.com/oulgen
ghstack dependencies: #135286, #135306, #135377
2024-09-08 18:02:36 +00:00
16f5155992 [inductor] Fast path for extract_read_writes without tracing (#135377)
Before (bottom of stack):
![image](https://github.com/user-attachments/assets/13060ff9-b31d-42a9-8e8f-c50b2bf3dc2f)

After (this PR):
![image](https://github.com/user-attachments/assets/7d190821-b614-46b7-9e9e-9087443df654)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135377
Approved by: https://github.com/oulgen
ghstack dependencies: #135286, #135306
2024-09-08 18:02:32 +00:00
37144be03d [inductor] Remove ReadWrites.op_counts (#135306)
This was (almost) unused.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135306
Approved by: https://github.com/oulgen
ghstack dependencies: #135286
2024-09-08 18:02:28 +00:00
3bdc54ed18 [inductor] Refactor LoopBody.memory_usage (#135286)
This is preparing for some other changes where I speed up extract_read_writes tracing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135286
Approved by: https://github.com/oulgen
2024-09-08 18:02:24 +00:00
cyy
2196f32475 [22/N] Fix clang-tidy warnings in jit (#135319)
Follows #134537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135319
Approved by: https://github.com/titaiwangms
2024-09-08 17:18:29 +00:00
cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I have to create a new PR because the previously reverted PR could not be rebased or imported successfully :(

----

Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve the BC for users still using the torch.distributed._tensor, I added a shim script to redirect old path calls to the new module

The BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it's safe to land the changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00
20cab91a12 [dynamo] Remove skip from jit freeze tests (#135281)
Fixes https://github.com/pytorch/pytorch/issues/119781
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135281
Approved by: https://github.com/zou3519
2024-09-08 15:11:12 +00:00
a6fae2e811 Use BRGEMM for Half flash attention forward kernel (#131879)
Use oneDNN BRGEMM on packed data to get better performance on the 5th generation of Xeon where Intel® Advanced Matrix Extensions (AMX) will have fp16 support, e.g. amx-fp16.
Multiple models have achieved acceleration, for instance, FP16 stable diffusion v2.1 has achieved over 50% improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131879
Approved by: https://github.com/jgong5, https://github.com/peterbell10
ghstack dependencies: #131878
2024-09-08 12:32:23 +00:00
042f2f7746 [ONNX] Re-raise the exception if the dynamic shapes cannot be refined (#135418)
Improve error reporting. Otherwise, users will mostly just see that the shapes could not be refined, without the original exception.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135418
Approved by: https://github.com/titaiwangms
2024-09-08 05:30:34 +00:00
fd494dd426 Change wrapped_linear_prepack and wrapped_quantized_linear_prepacked to private by adding _ as prefix (#135401)
Summary: In https://github.com/pytorch/pytorch/pull/134232, we added two new ops wrapped_linear_prepack and wrapped_quantized_linear_prepacked. From the review comments and offline discussion, we are changing them to private by adding `_` as prefix

Differential Revision: D62325142

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135401
Approved by: https://github.com/houseroad
2024-09-08 04:16:24 +00:00
8334cb2fb9 remove commented out breakpoints (#135363)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135363
Approved by: https://github.com/oulgen
2024-09-08 02:15:45 +00:00
e72ed4717e [Dynamo] Fix Huggingface PretrainedConfig get non const attr (#135413)
Fixes #135329

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135413
Approved by: https://github.com/anijain2305
2024-09-07 19:16:29 +00:00
3bebc09be9 [FlexAttention] Align the matmul tensorcore usage (#135168)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135168
Approved by: https://github.com/Chillee
2024-09-07 16:33:41 +00:00
a2db22e6bb [inductor] Catch BrokenProcessPool and print a more helpful message. (#135120)
Summary: BrokenProcessPool means a parallel-compile subprocess exited, which we never expect. It's likely due to a crash, so print a more meaningful error message and instructions that it's probably easier to debug by turning off parallel compile. Output looks like:
```
...
  File "/data/users/slarsen/pytorch/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_slarsen/4q/c4qw7xk5lbb7whg5txnk4hwbc7z6kepak3o666tr3d64gcad5r5b.py", line 815, in <module>
    async_compile.wait(globals())
  File "/data/users/slarsen/pytorch/torch/_inductor/async_compile.py", line 265, in wait
    raise RuntimeError(
RuntimeError: A compilation subprocess exited unexpectedly. This is likely due to a crash. To facilitate debugging, you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 to cause compilation to occur in the main process.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135120
Approved by: https://github.com/Chillee
2024-09-07 16:33:37 +00:00
eac5e12548 [inductor] Move LoopBody to its own file (#135257)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135257
Approved by: https://github.com/oulgen
2024-09-07 16:29:15 +00:00
18479c5f70 [Doc] update max-autotune for CPU (#134986)
The current doc for `max-autotune` is applicable only for GPU. This PR adds the corresponding content for CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134986
Approved by: https://github.com/jgong5, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-09-07 13:42:40 +00:00
f7c0c06692 Add oneDNN BRGEMM support on CPU (#131878)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131878
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-09-07 13:22:30 +00:00
b53d97c7be [Intel GPU] Add XPU memory-related APIs (#129919)
# Motivation
According to https://github.com/pytorch/pytorch/issues/116322, we will help unify the device allocator. So we introduce a simple xpu device allocator with only the key functionality first, and expect to add memory-statistics-related functionality after the unification.
But now, some memory-statistics-related APIs listed in https://github.com/pytorch/pytorch/issues/127929 have been requested, and we need more time to unify the device allocator. To facilitate the user experience, we expect to support these APIs before the unification.

# Additional Context
Fixes: #127929

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129919
Approved by: https://github.com/dvrogozh, https://github.com/abhilash1910, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #130923
2024-09-07 11:15:17 +00:00
6c1da66407 [Reland] Refactor caching device allocator utils (#130923)
# Motivation
Following [[RFC] Intel GPU Runtime Upstreaming for Allocator ](https://github.com/pytorch/pytorch/issues/116322), this PR aims to refactor the caching device allocator utils to improve code reuse.
This is the first PR, we could prepare some follow-up PRs continuing to refactor the device caching allocator.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130923
Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD, https://github.com/eqy
2024-09-07 11:14:17 +00:00
d7c97e7245 [inductor][cpp][gemm] cache blocking config for dynamic shapes (#133538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133538
Approved by: https://github.com/leslie-fang-intel
ghstack dependencies: #135277, #133447

Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
2024-09-07 11:09:30 +00:00
be9f4ffe88 [inductor][cpp][gemm] enable dynamic M for k-slicing (#133447)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133447
Approved by: https://github.com/leslie-fang-intel
ghstack dependencies: #135277

Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
2024-09-07 11:09:30 +00:00
692faa9bc6 [inductor][cpp][gemm] reduce memory alloc overhead by allocating local acc once per thread (#135277)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135277
Approved by: https://github.com/leslie-fang-intel

Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
2024-09-07 11:09:25 +00:00
32f3af72b7 [ONNX] Support FakeTensor in ONNXProgram (#135399)
Sync with https://github.com/justinchuby/torch-onnx/compare/v0.1.20...v0.1.21 to support FakeTensors in ONNXProgram. Specifically, this PR implements the `apply_weights` method to allow users to supply a dictionary of concrete tensors to replace FakeTensors in the exported model weights.
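A hedged sketch of the workflow (the export call, the model, and the checkpoint path are assumptions; the commit itself only specifies that `apply_weights` takes a dictionary of concrete tensors):

```python
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

# In the FakeTensor workflow the model would be constructed under fake mode;
# a concrete model is used here only to keep the sketch self-contained.
onnx_program = torch.onnx.export(TinyModel(), (torch.randn(2, 16),), dynamo=True)

state_dict = torch.load("real_weights.pt")  # hypothetical checkpoint of concrete tensors
onnx_program.apply_weights(state_dict)      # replace FakeTensors with real values
onnx_program.save("model.onnx")             # serializing FakeTensors directly raises instead
```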

An error is raised when users try to serialize a FakeTensor to avoid segfaults.

Also fixed a bug in `.save()` when `keep_initializers_as_inputs` is True and `include_initializers` is False.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135399
Approved by: https://github.com/titaiwangms
2024-09-07 04:48:18 +00:00
ebab5c85c4 [FlexAttention] Skip very small block size unit tests on H100 due to Triton bug (#135393)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135393
Approved by: https://github.com/BoyuanFeng
2024-09-07 04:35:22 +00:00
3d734d837b [ONNX] Handle mixed sequence inputs properly (#135378)
Previously, when an input contains a mixture of `Value` and python constants like `[SymbolicTensor('sym_size_int_3', type=Tensor(INT64), shape=[], producer=node_Shape_0, index=0), 512]`, we get errors like

```pytb
Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_building.py", line 367, in _call_op
    converted_named_inputs = _process_python_constants_and_sequences(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_building.py", line 275, in _process_python_constants_and_sequences
    raise TypeError(
TypeError: Constant input '[SymbolicTensor('sym_size_int_3', type=Tensor(INT64), shape=[], producer=node_Shape_0, index=0), 512]' of type '<class 'list'>' is not supported
```

This PR updates Sequence handling to support this case, as well as variadic inputs and ONNX Sequence inputs.

Synced from https://github.com/justinchuby/torch-onnx/pull/187
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135378
Approved by: https://github.com/titaiwangms
2024-09-07 03:07:39 +00:00
c92227c41a [quant][pt2e] fix placeholder typo and related quantization tests (#135379)
A previous typo in "placeholder" and the related tests in quantization are fixed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135379
Approved by: https://github.com/jerryzh168
2024-09-07 02:31:43 +00:00
e6a0221fc6 [Inductor] Optionally allow padding on non-GPU devices (#135280)
This is the OSS component of a larger MTIA diff.

Currently, Inductor disables padding for non-GPU devices. We need to change this behavior to enable padding on MTIA.

This PR adds a config option to enable padding on the CPU, or any other non-GPU device. In the future, we might want to enable padding on all devices by default. However, that might require supporting device-dependent padding defaults, since CPUs will likely use different settings than H100 GPUs.

Differential Revision: D61038114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135280
Approved by: https://github.com/jfix71, https://github.com/shunting314
2024-09-07 02:19:14 +00:00
a6b9d444fb [ONNX] Refactor exporter errors (#135180)
Refactor exporter errors to combine old errors and new errors for API consistency.

This PR also

1. Removes the `_C._check_onnx_proto(proto)` call in the old exporter. We don't need the ONNX checker because it is limited.
2. Removes the `OnnxExporterError` defined in the dynamo module. This class unnecessarily stores the onnx program object, making it very bulky. Instead, we revert to using the plain OnnxExporterError defined in the `errors` module and use it as the base class for all errors.
3. Continues to expose `OnnxExporterError` in `torch.onnx` and the rest of the errors in `torch.onnx.errors`.
4. Removes the `CheckerError` and `InvalidExportOptionsError` from `torch.onnx`. This is BC breaking but should have low impact.
5. I did not rename existing errors out of compatibility considerations, even though `ExporterError` would have been more succinct.

Fixes https://github.com/pytorch/pytorch/issues/135125
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135180
Approved by: https://github.com/titaiwangms
2024-09-07 00:50:15 +00:00
d42b0c8f22 Add release matrix for 2.5 (#135383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135383
Approved by: https://github.com/huydhn
2024-09-07 00:49:53 +00:00
941d094dd1 [Dynamo][DTensor] Fixes SymNodeVariable() is not a constant error in Compiled DDP + TP unit test (#135315)
Before the fix, the unit test will fail at forward Dynamo tracing:
```
  File "/data/users/willfeng/pytorch/test/distributed/_composable/test_replicate_with_compiler.py", line 415, in test_ddp_tp
    loss = compiled_replicate_model(data).sum()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
torch._dynamo.exc.InternalTorchDynamoError: SymNodeVariable() is not a constant

from user code:
   File "/data/users/willfeng/pytorch/torch/distributed/tensor/parallel/_data_parallel_utils.py", line 34, in _unflatten_tensor
    result = DTensor.from_local(
```
After the fix, the compilation fails at a later step (Compiled Autograd tracing), due to needing "pre-dispatch tracing of backward graph" feature (see details at https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474).

I believe this PR is a net improvement, because it should also fix the 1D Traceable FSDP2 failure case on internal models (https://github.com/pytorch/pytorch/issues/130978#issuecomment-2319476690), which is much harder to build a minimal unit test for.

Fixes https://github.com/pytorch/pytorch/issues/130978.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135315
Approved by: https://github.com/bdhirsh
2024-09-07 00:11:25 +00:00
b1a934741e Change test_constant_prop_preserve_metadata (#135268)
Summary: In new export_for_training, "stack_trace" does not exist in node meta anymore.

Test Plan:
```
buck run fbcode//mode/dev-nosan fbcode//caffe2/test:quantization_pt2e -- -r test_constant_prop_preserve_metadata
```

Reviewed By: angelayi

Differential Revision: D62219974

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135268
Approved by: https://github.com/angelayi
2024-09-07 00:02:35 +00:00
0c661f3e1a [Split Build] Refactor split build binary builds into their own workflows and move split build binary builds to periodic (#134624)
As we need to move split build binary tests from trunk to periodic, this PR refactors those jobs out into their own workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134624
Approved by: https://github.com/malfet
2024-09-06 23:57:56 +00:00
2c7e314803 [Inductor][CPP] Fix the issue of view dtype (#135301)
**Summary**
Fix issue: https://github.com/pytorch/pytorch/issues/135160, it's a regression introduced by https://github.com/pytorch/pytorch/pull/134569, where the dtype of `to_dtype_bitcast` was incorrectly handled when using the scalarize implementation.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_view_dtype
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135301
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-09-06 23:36:44 +00:00
ead4407f57 [inductor] Fix loop split optimization (#135303)
Fix https://github.com/pytorch/pytorch/issues/135274.

Improve the check for whether the div expr matches: add a check that `split_var` is in `original_body.iter_vars`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135303
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-09-06 23:06:25 +00:00
2f5b40c099 [aoti test] Disable FP8 funz dtypes in fp8 runtime check test (#135373)
Fixing https://github.com/pytorch/pytorch/issues/126734

The key is that the fnuz FP8 types are for AMD only.

source: https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135373
Approved by: https://github.com/chenyang78
2024-09-06 23:05:47 +00:00
993b5647ab [export] fix placeholder name collision tests by removing map call (#135366)
The current test is failing because of the currently unstable state of map: torch.compile and non-strict export take two separate routes, unlike cond and while_loop. This PR fixes the test itself. We'll fix map in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135366
Approved by: https://github.com/angelayi
2024-09-06 22:02:50 +00:00
2ab26806f1 Require tlparse for failing tests in test_structured_trace.py (#135376)
Summary: These tests are currently failing internally. Per discussion, skip if tlparse is unavailable

Test Plan:
```
feature remove tlparse
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --run-disabled --regex test_structured_trace.py
feature install tlparse
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --run-disabled --regex test_structured_trace.py
```

Differential Revision: D62310342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135376
Approved by: https://github.com/ezyang
2024-09-06 21:53:41 +00:00
b1612569f6 [BE] Clarify defaulting behavior in optimizer (#135384)
Fixes #135340

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135384
Approved by: https://github.com/drisspg, https://github.com/jainapurva
2024-09-06 21:52:55 +00:00
dc0e818738 [FR] Automatically infer a common filename prefix (#135158)
Save the annoyance of specifying this on the command line each time
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135158
Approved by: https://github.com/fduwjj, https://github.com/c-p-i-o
ghstack dependencies: #135157
2024-09-06 21:44:27 +00:00
06e414d7fe [FR] Make trace_dir a required argument (#135157)
Ensures users get a clean error if they forget to specify the dir, and
improves the help message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135157
Approved by: https://github.com/c-p-i-o, https://github.com/fduwjj
2024-09-06 21:44:27 +00:00
a681260caf Revert "[ONNX] Refactor exporter errors (#135180)"
This reverts commit 5eebd9315a72422d59b6f8d8ca8e4e573e231d5c.

Reverted https://github.com/pytorch/pytorch/pull/135180 on behalf of https://github.com/clee2000 due to I think this broke test_public_bindings.py::TestPublicBindings::test_correct_module_names [GH job link](https://github.com/pytorch/pytorch/actions/runs/10743909338/job/29800779403) [HUD commit link](5eebd9315a), possibly a landrace with the PR that landed before it ([comment](https://github.com/pytorch/pytorch/pull/135180#issuecomment-2334844191))
2024-09-06 21:39:18 +00:00
95e976a63f [dynamo] recursively skip frames when Dynamo cache limit is hit (#135144)
Fixes https://github.com/pytorch/pytorch/pull/135144 and [T197117723](https://www.internalfb.com/intern/tasks/?t=197117723).

In general, adds `SkipCodeRecursiveException` to Dynamo - when raised in Dynamo, convert_frame will return a `skip_code_recursive_flag` back to C Dynamo, signaling it to skip the current frame and all recursive calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135144
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-09-06 21:38:53 +00:00
306ac44eaa [ez][TD] Fix request for issue body returns None (#135389)
I assumed it would be an empty string if the body is empty, but it's just None.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135389
Approved by: https://github.com/malfet
2024-09-06 21:02:01 +00:00
a7643baceb Revert expectFailureIf condition on tests with torch.compile on Windows (#134759)
Fixes #134716

This PR reverts some changes introduced in 6eae569546 (#133987)

torch.compile is not available on Windows, tests should be expected to fail.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134759
Approved by: https://github.com/malfet
2024-09-06 20:51:55 +00:00
a4030e37be [dynamo] reland map/zip iterator related changes (#135074)
Differential Revision: [D62211019](https://our.internmc.facebook.com/intern/diff/D62211019)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135074
Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/mlazos
2024-09-06 20:38:02 +00:00
22e1fb6faa [test][easy] Add debug utils for cpu select algorithm test (#135038)
Summary: Add debug utils to debug a flaky test in fbcode ci.

Some context: https://github.com/pytorch/pytorch/pull/126545

Test Plan: ci

Differential Revision: D62005445

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135038
Approved by: https://github.com/jgong5, https://github.com/XuehaiPan
2024-09-06 20:30:49 +00:00
2a4890e315 [ONNX] Clean up the missed lines from previous PRs (#135368)
Some missed deleted lines

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135368
Approved by: https://github.com/justinchuby
2024-09-06 20:27:52 +00:00
3ce433aef2 [TCPStore] use wait counters (#135283)
This replaces the existing TCPStore counters with the new shared wait counters. There are no users of the TCPStore counters, so they should be completely safe to remove.

Test plan:

Existing tests + build

There's no OSS backend for wait counters, so we can't write any tests with them currently.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135283
Approved by: https://github.com/c-p-i-o
2024-09-06 19:54:25 +00:00
7f2d20e687 Run all autograd node post hooks (#134728)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134728
Approved by: https://github.com/albanD, https://github.com/soulitzer
2024-09-06 19:44:28 +00:00
32fd29c1ea [ONNX] Properly handle Attributes in traceable functions (#135367)
Previously, the attributes were sent in as Attr objects even when we call the function as a plain Python function. This change turns them into Python objects.

From https://github.com/justinchuby/torch-onnx/pull/186
Related https://github.com/microsoft/onnxscript/issues/1846

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135367
Approved by: https://github.com/justinchuby
2024-09-06 19:35:22 +00:00
5eebd9315a [ONNX] Refactor exporter errors (#135180)
Refactor exporter errors to combine old errors and new errors for API consistency.

This PR also

1. Removes the `_C._check_onnx_proto(proto)` call in the old exporter. We don't need the ONNX checker because it is limited.
2. Removes the `OnnxExporterError` defined in the dynamo module. This class unnecessarily stores the onnx program object, making it very bulky. Instead, we revert to using the plain OnnxExporterError defined in the `errors` module and use it as the base class for all errors.
3. Continues to expose `OnnxExporterError` in `torch.onnx` and the rest of the errors in `torch.onnx.errors`.
4. Removes the `CheckerError` and `InvalidExportOptionsError` from `torch.onnx`. This is BC breaking but should have low impact.
5. I did not rename existing errors out of compatibility considerations, even though `ExporterError` would have been more succinct.

Fixes https://github.com/pytorch/pytorch/issues/135125
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135180
Approved by: https://github.com/titaiwangms
2024-09-06 19:10:56 +00:00
a15aabc975 Add MaskedTensor passthrough: unfold, F.Unfold, F.Fold, stack (#125262)
Hi,
I noticed the `unfold` operator was missing on MaskedTensor.

I tested that my change works when calling unfold and backward on a `MaskedTensor` but I didn't find the tests for the dispatch of such operation. Where is it?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125262
Approved by: https://github.com/cpuhrsch
2024-09-06 19:06:23 +00:00
b143426db3 [Inductor] Use argument names as the key for the constants dict and the signature dict (#135170)
Referencing how triton constructs these dictionaries

ca3fb5f6fa/python/triton/runtime/jit.py (L639)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135170
Approved by: https://github.com/htyu
2024-09-06 19:05:00 +00:00
13ba0a2e5c Run bypassed graph compile outside the except block to avoid chaining of exceptions (#135175)
Fixes #135172

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135175
Approved by: https://github.com/masnesral, https://github.com/ezyang
2024-09-06 19:03:57 +00:00
8520ce5f78 Fix incorrect trace of post-accumulate grad hook on tensor with zero dims (#135226)
Fix incorrect trace of post-accumulate grad hook on tensor with zero dimensions

Fixes #135207

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135226
Approved by: https://github.com/xmfan
2024-09-06 18:19:54 +00:00
196748d491 [elastic] support local_addr across all rendezvous impls (#135262)
Summary:
There was a regression introduced in https://github.com/pytorch/pytorch/pull/125743 that caused `local_addr` to no longer be used. This fixes that by passing `local_addr` to `RendezvousStoreInfo.build` everywhere it's used.

This also fixes a number of tests allowing them to be run in parallel which hugely sped up the testing cycle as this change touches many different rendezvous implementations. This required a few fixes in unrelated tests.

Test Plan:
Added tests for the common rendezvous implementations that `local_addr` to prevent future regressions.

```
buck2 test @//mode/dev-nosan fbcode//caffe2/test/distributed/elastic/... fbcode//caffe2/torch/distributed/elastic/... -- --stress-runs 3
```

To vet the parallelism changes I also ran with 3 stress runs each to identify flakiness caused by parallelism.

Differential Revision: D62256407

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135262
Approved by: https://github.com/fduwjj, https://github.com/wz337
2024-09-06 17:55:43 +00:00
177e4f4218 remove _check call on item() for torch.istft (#135234)
Fixes #135014

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135234
Approved by: https://github.com/tugsbayasgalan
2024-09-06 17:31:25 +00:00
3988b3468b [aoti][easy] remove breakpoint() in wrapper.py (#134807)
Differential Revision: D61687146

Remove an unintended breakpoint in code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134807
Approved by: https://github.com/YUNQIUGUO
2024-09-06 17:25:05 +00:00
04118d8617 [export] Record the global torch version in serialization. (#135243)
Summary: In general I think it will be useful to also record the global torch version in the EP, so that we can track them in the logging in addition to the schema version.

Test Plan: CI

Reviewed By: henryoier

Differential Revision: D62252626

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135243
Approved by: https://github.com/yushangdi
2024-09-06 17:02:06 +00:00
24482e5c68 [torch][fx] Set maximum warning count during fx.Graph.lint (#135069)
Summary:
resnet152 spent about 15 minutes writing warning messages in _unlift
during `to_executorch` because they're all written to unbuffered stderr
by the `warnings` module.

These warnings are almost always about get_attr nodes referencing a
non-existent name:
```lang=py
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
  'not reference an nn.Module, nn.Parameter, or buffer, which is '
  'what \'get_attr\' Nodes typically target'
)
```
I'm not aware of a way to configure the warnings module to write this out
at most once, so I'm just going to disable the lint for now.

Test Plan:
Re-ran resnet152 with Executorch and the XNNPackBackend, it is much faster now

Differential Revision: D62156090

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135069
Approved by: https://github.com/yushangdi
2024-09-06 16:41:59 +00:00
c0ec599f27 Update submodule ideep to include aarch64 change (#134897)
This PR is per ARM request, which is in https://github.com/intel/ideep/issues/334.

Context for the request: the Arm team has upstreamed the dynamic quantization changes, and all the PRs (torch, ideep, oneDNN) were merged, but without this ideep submodule update the feature will not work. The change is isolated to the matmul operator and the quantization path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134897
Approved by: https://github.com/jgong5, https://github.com/atalman, https://github.com/snadampal
2024-09-06 16:40:26 +00:00
7074de43c0 Porting to GCC 15 (#135188)
uint8_t is declared in the <cstdint> header

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135188
Approved by: https://github.com/Skylion007
2024-09-06 16:16:53 +00:00
771dcce11d [AOTI][Tooling][6/n] Fix long dtype input tensors calling mean() in aoti_torch_print_tensor_handle (#135072)
Differential Revision: D61635232

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135072
Approved by: https://github.com/hl475, https://github.com/ColinPeppler
2024-09-06 15:59:32 +00:00
de74aafff4 error on exporting ScriptModule (#135302)
Test Plan: added test

Differential Revision: D62279179

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135302
Approved by: https://github.com/yushangdi
2024-09-06 15:12:40 +00:00
ad29a2c0dc Add Inductor config for default stride behavior (#135238)
By default, Inductor is allowed to manipulate the layout
(strides+storage offset) of input tensors to custom operators.

We want to change it so that the default is that Inductor should respect
the stride order of input tensors to custom operators.

This PR adds a config to toggle the behavior, in the next PR up we'll
change the default. We also make the following changes:
- We add a new operator Tag (flexible_layout), which means that
inductor is allowed to manipulate the layout. When we flip the default,
users can specify they want the old behavior by using this tag.
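A rough sketch (not from the PR itself) of how a custom op author might opt back into layout flexibility, assuming the tag is exposed as `torch.Tag.flexible_layout` and that `torch.library.define` accepts a `tags` argument; the `mylib::my_op` namespace and op name are made up for illustration:
```python
import torch

# Hypothetical custom op that opts into letting Inductor manipulate
# the layout (strides + storage offset) of its inputs.
torch.library.define(
    "mylib::my_op",
    "(Tensor x) -> Tensor",
    tags=(torch.Tag.flexible_layout,),
)

@torch.library.impl("mylib::my_op", "CompositeExplicitAutograd")
def my_op(x):
    return x.clone()
```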

This is a reland of https://github.com/pytorch/pytorch/pull/126986,
which was previously reverted due to silent incorrectness. We've since
fixed the silent incorrectness
(https://github.com/pytorch/pytorch/pull/133639)

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135238
Approved by: https://github.com/albanD
2024-09-06 14:48:24 +00:00
3a9e33dca8 [torchelastic] Don't do signal handling when off the main thread (#135088)
Summary:
In multiprocessing, signal handling is not possible if the thread is not the main thread. This resulted in the following error:
> "ValueError('signal only works in main thread of the main interpreter')"

To address this issue, the diff checks whether the thread is the main thread and, if not, skips signal handling.
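A minimal sketch (not the actual torchelastic code) of the main-thread check being described:
```python
import signal
import threading

def register_handlers_if_main_thread(handler):
    # signal.signal() raises "ValueError: signal only works in main thread
    # of the main interpreter" when called off the main thread, so skip
    # signal handling entirely in that case.
    if threading.current_thread() is not threading.main_thread():
        return
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, handler)
```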

Test Plan:
Before this change, MAST job failed:
https://fburl.com/mlhub/iq2m10v8

With this change, MAST job succeeded:
https://fburl.com/mlhub/q6kb8343

Differential Revision: D62166943

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135088
Approved by: https://github.com/d4l3k
2024-09-06 14:47:03 +00:00
a086882d72 [inductor][triton] mark workspace args as mutated (#134648)
SplitScan makes use of a workspace arg that needs to be zeroed before it is used; it is then used to communicate between thread blocks in the triton kernel implementation. It is mutated during the execution of the kernel, so it should be marked as such.

Before this PR, it is not marked as mutated; AFAIK this is fine during normal execution, but during autotuning it causes problems. The workspace starts off zeroed (as expected), but during autotuning the kernel will be executed multiple times and the workspace does not get re-set between executions, resulting in incorrect data. If the data is used for indexing, then you can fail device-side asserts (and the results after the initial run (with autotuning) could be wrong). The test added in this PR repros the issue when the fix is removed.

When we mark the arg as mutated, then the arg gets cloned before autotuning, so that the arg passed to the kernel during autotuning will always be zeroed as expected.
804852c1f9/torch/_inductor/runtime/triton_heuristics.py (L685-L689)
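A simplified illustration (hypothetical helper, not Inductor's actual autotuning code) of why cloning mutated args matters when a kernel is benchmarked repeatedly:
```python
import torch

def benchmark_kernel(kernel, args, mutated_idxs, n_iters=3):
    # Each benchmarking run must see the workspace in its original
    # (zero-initialized) state, so clone any argument the kernel mutates
    # instead of reusing the already-mutated buffer.
    for _ in range(n_iters):
        run_args = [
            a.clone() if i in mutated_idxs and isinstance(a, torch.Tensor) else a
            for i, a in enumerate(args)
        ]
        kernel(*run_args)
```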

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134648
Approved by: https://github.com/peterbell10, https://github.com/jansel
2024-09-06 14:23:37 +00:00
84ae6b7d6b AOTDispatcher: limit cases when we detach() graph inputs to non-leaves (#134193)
This PR is slightly a revival / update to the discussion from https://github.com/pytorch/pytorch/pull/98960:

Part of FSDP2's tracing strategy right now is that:

(1) it is painful/difficult to handle the case where we have multiple graph input tensors that are aliased to each other and at least one of them is duplicated

(2) we already have longstanding logic to remove duplicate input tensors from the graph in dynamo. Morally, FSDP2 gives us duplicate input tensors in the backward graph for every `unsharded_param`, because we have (a) the `unsharded_param` being closed over by the backward hook to resize/allgather, and (b) the same `unsharded_param` being saved for backward by autograd (we now guarantee in the partitioner that we will always save the base tensor for backward and recompute views)

(3) However, we were still seeing cases where the `unsharded_param` showed up twice in the backward graph inputs, as distinct tensor objects (with different python ids) instead of being true duplicates that dynamo can de-dup.

It turns out that this was because we were `.detach()`ing the `unsharded_param` in AOTDispatcher before plumbing it through the compiled forward (and so autograd would save a detach'd version of the `unsharded_param`). This is precisely because of the logic from https://github.com/pytorch/pytorch/pull/98960.

However, re-reading the detailed comments, it seems unnecessary to do a detach() on a graph input that is a (leaf) `nn.Parameter`, even if it happens to get no gradients in the backward. Since it is a leaf, we don't have to worry about the autograd engine "continuing to backprop through the graph beyond the current tensor" (the leaf has no other grad_fn for autograd to backprop through).

So this PR makes us a bit less aggressive about calling detach() on inputs: we only do it when:

(1) our graph input statically will get a `None` gradient (and also has no metadata mutations, the existing state)

(2) **and** our graph input is a non-leaf tensor (so detach()ing is actually required to prevent autograd from incorrectly backpropping past the non-leaf).
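A toy illustration of the leaf vs. non-leaf distinction this relies on (plain autograd, not AOTDispatcher internals):
```python
import torch

leaf = torch.nn.Parameter(torch.randn(3))         # leaf: no grad_fn behind it
nonleaf = torch.randn(3, requires_grad=True) * 2  # non-leaf: has a grad_fn

print(leaf.is_leaf, nonleaf.is_leaf)  # True False
# If a graph input gets no gradient, detach() only matters for the non-leaf:
# autograd could otherwise keep backpropagating through nonleaf.grad_fn,
# while the leaf has no grad_fn for autograd to continue through.
```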

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134193
Approved by: https://github.com/yf225

Co-authored-by: Will Feng <yf225@cornell.edu>
2024-09-06 14:06:48 +00:00
60a097a071 [CD] Update binary_linux_test.sh to include calling builder smoke test (#133869)
Run smoke test

Fixes #1969

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133869
Approved by: https://github.com/atalman

Co-authored-by: Andrey Talman <atalman@fb.com>
2024-09-06 13:27:24 +00:00
13bae39e22 [inductor] [cpp] improve cache blocking for is_dynamic_M (#131306)
## Performance
Models with >= 3% performance speedup are listed below:

### AMP single-thread dynamic shape (measured on CPU with AMX support)
No regressions

| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| torchbench | soft_actor_critic | 3% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131306
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
ghstack dependencies: #135275

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
2024-09-06 13:21:24 +00:00
4ef6c05f65 [inductor][cpp][gemm] fix autotune runtime error from linear_binary fusion (#135275)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135275
Approved by: https://github.com/leslie-fang-intel
2024-09-06 13:21:23 +00:00
d6b9bd3e60 Also handle compiler collective when input variable doesn't exist on all ranks (#135147)
Internal xref:
https://fb.workplace.com/groups/3095840833991792/permalink/3810738595835342/

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135147
Approved by: https://github.com/jansel
2024-09-06 13:18:36 +00:00
d0591f4658 Ignore fresh unbacked when doing recursive make_fx inside HOPs (#135053)
Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7705964779531357/

This now also incorporates a test from https://github.com/pytorch/pytorch/pull/133585 (which it fixes) and the prep PR https://github.com/pytorch/pytorch/pull/134407 Including the PR desc from that:

I am trying to fix a problem reported by user in [fb.workplace.com/groups/6829516587176185/permalink/7705964779531357](https://fb.workplace.com/groups/6829516587176185/permalink/7705964779531357/) The summary of this problem is that when we do collect metadata analysis in AOTAutograd, we accumulate pending unbacked symbols which are going to be discarded at the end of the trace. However, if we do a recursive make_fx inside tracing, as occurs with torch.cond, we end up seeing that there are pending unbacked symbols that aren't associated with a binding, even though it's spurious (they've leaked into the inner make_fx call from the outer AOTAutograd analysis).

In https://github.com/pytorch/pytorch/pull/133588 I tried to just prevent adding the symbols to the pending list at all in the first place. But this itself caused some problems which were fixed in https://github.com/pytorch/pytorch/pull/124785 . The problem fixed in that PR is that when we allocate tangents that have unbacked size, something prevented them from having correct unbacked SymInts when ignore fresh unbacked SymInts was enabled. So I had patched it at the time by just not suppressing pending symbols and clearing them out some other way.

I think... I was wrong in that PR? That is to say, it was OK to avoid putting the fresh unbacked symbols in the pending list; the real problem was suppressing unbacked renamings. But there doesn't seem to be a good reason to suppress these; this PR shows that it doesn't actually fail any tests if you do these anyway. Intuitively, this makes sense, because you can't trigger renamings unless you're actually adding unbacked symbols to the pending set.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135053
Approved by: https://github.com/ydwu4
2024-09-06 13:13:15 +00:00
b5dea061c8 check compilation status before query cudnn version in conv (#135332)
This PR fixes https://github.com/pytorch/pytorch/issues/135322. The cuDNN compilation status should be checked first, before querying the version; otherwise, conv may trigger a RuntimeError before any check runs on other, non-CUDA backends.
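The actual fix is in the C++ convolution path; a Python-level analogue of the same check-before-query pattern, for illustration only:
```python
import torch

def cudnn_version_or_none():
    # Only query the cuDNN version if PyTorch was built with cuDNN support;
    # querying unconditionally can raise on non-CUDA backends/builds.
    if torch.backends.cudnn.is_available():
        return torch.backends.cudnn.version()
    return None

print(cudnn_version_or_none())
```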

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135332
Approved by: https://github.com/EikanWang, https://github.com/atalman
2024-09-06 12:50:04 +00:00
041960a1ce [Dynamo] Automatically in-graph traceable tensor subclass ctors (#135151)
Fixes https://github.com/pytorch/pytorch/issues/114389

Previously, dynamo would attempt to trace through the `__init__` of traceable tensor subclasses, since their constructors are AOT dispatcher traceable by definition, dynamo should automatically put these in the graph like we do for any other tensors. Not doing this is difficult because dynamo would need to apply mutations post tensor subclass creation in the graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135151
Approved by: https://github.com/bdhirsh
2024-09-06 12:23:38 +00:00
67c7924ea1 [inductor] Fix gen_transposed_tile_load_store (#135307)
The recent PR https://github.com/pytorch/pytorch/pull/131745 brought new VLA (variable-length array) logic into the cpp codegen, which raises a build failure on MSVC with error code `Compiler Error C2131`: https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2131?view=msvc-170

reproduce UT:
```cmd
pytest test\inductor\test_torchinductor_dynamic_shapes.py -v -k test_large_block_sizes_dynamic_shapes_cpu
```

Original generated code:
```c++
alignas(16) float tmp1[static_cast<int64_t>(((-256LL)*(c10::div_floor_integer(static_cast<int64_t>(ks1), static_cast<int64_t>(16LL)))) + (16LL*ks1))];
```

Changes:
allocate a large-enough fixed-sized buffer.

New generated code:
```c++
alignas(16) float tmp1[16*16];
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135307
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-09-06 10:44:08 +00:00
217ba7b2ab [Docs] Update FileCheck doc (#135199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135199
Approved by: https://github.com/soulitzer
2024-09-06 08:18:38 +00:00
758d515d98 [Inductor][CPP] Select tiling factor for lower precision data types (#133830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133830
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-09-06 08:12:37 +00:00
60d98b4cfb Update torch-xpu-ops pin (ATen XPU implementation) (#135300)
Release cycle for PyTorch 2.5
1. Bugfixing: correct reduction logic in cdist kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135300
Approved by: https://github.com/EikanWang
2024-09-06 07:30:09 +00:00
590a3e9f8a [export][training ir migration] quantized_decomposed.quantize_per_tensor decomposition (#134525)
Summary:
In the graph of the TestXNNPACKQuantizer.test_dynamic_linear_with_conv test, some quantized_decomposed.quantize_per_tensor.default ops become quantized_decomposed.dequantize_per_tensor.tensor ops when using the new training IR.

This is because we lift params/buffers before calling make_fx. Previously, for the graph passed to make_fx, `graph.L__self___linear1.weight` was a tensor;
now, in the training IR, `graph.L__self___linear1.weight` is a FakeTensor. This causes the node overload to be different.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_dynamic_linear_with_conv
```

Differential Revision: D61364547

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134525
Approved by: https://github.com/tugsbayasgalan, https://github.com/jerryzh168
2024-09-06 07:06:06 +00:00
764ee6e3f9 [FlexAttention] Specify padding_value for boundary checked loads (#134573)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134573
Approved by: https://github.com/Chillee
2024-09-06 06:47:26 +00:00
67f98a99a4 [DeviceMesh][Easy] Make RuntimeError a bit more descriptive by including the actual world_size (#135271)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135271
Approved by: https://github.com/fduwjj
2024-09-06 06:23:20 +00:00
e020a8755a [Fix][FR][ez] Remove debugging logs (#135308)
Removing the print added during debugging process.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135308
Approved by: https://github.com/wz337
2024-09-06 06:14:33 +00:00
7ffb3b201c [inductor] Remove LoopBody.reads,writes,other (#135256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135256
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084, #135079, #135235
2024-09-06 06:11:55 +00:00
f946bf88c4 [inductor] Skip retracing an existing LoopBody (#135235)
This is roughly a 7% speedup in inductor compile time for hf_Bert_large.  The time spent in `LoopBody.__init__` improves from 15% to 8% of `fx_codegen_and_compile`.

Before
![image](https://github.com/user-attachments/assets/7de0f28e-35bd-472f-b4be-b52733d2a85c)

After
![image](https://github.com/user-attachments/assets/5f0cf11a-43c5-43ae-b13c-f32383a75a7f)

Overall
![image](https://github.com/user-attachments/assets/6a369d8c-fb5e-4ad2-9504-0fc745ad6568)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135235
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084, #135079
2024-09-06 06:11:55 +00:00
66da3b3b2a [fx] Bypass custom __setattr__ in Node.__init__ (#135079)
Before:
![image](https://github.com/user-attachments/assets/5f0a6ae6-6049-44d0-b5f2-a549a23ad97f)

After:
![image](https://github.com/user-attachments/assets/51c9f91b-f8a0-4043-8362-65813feec823)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135079
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084
2024-09-06 06:11:46 +00:00
41e653456e [RDP] Fix "No module named 'libfb’" (#135244)
Summary:
D62215095 introduced an import error in arvr pipelines, as the is_fbcode() function does not work as intended.

This changes is_fbcode() to be a much stricter check.

Test Plan:
```
buck2 run arvr/mode/platform010/opt-stripped //arvr/libraries/depthlink/clients/mr_replay:pipeline_runner -c bolt.use_eva3_sim=True -- --config_file arvr/libraries/depthlink/clients/mr_replay/configs/runner_config.yaml --features DEPTH
```

Differential Revision: D62237502

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135244
Approved by: https://github.com/aorenste
2024-09-06 04:52:31 +00:00
e40a0a9359 Add randomness checking for sdpa vmap (#135176)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135176
Approved by: https://github.com/zou3519
2024-09-06 04:50:49 +00:00
c05a7adb36 [inductor][debug] fix draw_buffers (#135266)
**Before:**
![image](https://github.com/user-attachments/assets/aac756f3-1349-4647-9da3-87cf105cf647)

**After:**
<img width="791" alt="image" src="https://github.com/user-attachments/assets/d72c663c-e598-42fa-ac40-9e58956f1ec1">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135266
Approved by: https://github.com/yf225
2024-09-06 04:12:41 +00:00
5f57be7571 [Distributed] Change function call in test to non-deprecated to eliminate warning (#134938)
Migrate function calls in the test to eliminate the warning messages shown below and reduce the chance of test failures when the deprecated methods are removed; a migration sketch follows the warning message.

-  from deprecated `save_state_dict` change to `save`
-  from deprecated `load_state_dict` change to `load`

Warning message:
```bash
pytorch/test/distributed/checkpoint/test_fsdp_model_state.py:37: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.

```
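A minimal sketch of the rename, assuming a single-process run and a local filesystem checkpoint directory (the real tests exercise FSDP model state):
```python
import torch
import torch.distributed.checkpoint as dcp

state_dict = {"weight": torch.randn(4, 4)}

# Deprecated calls being replaced:
#   dcp.save_state_dict(state_dict, storage_writer=dcp.FileSystemWriter("ckpt"))
#   dcp.load_state_dict(state_dict, storage_reader=dcp.FileSystemReader("ckpt"))

# Non-deprecated equivalents used by the migrated tests:
dcp.save(state_dict, storage_writer=dcp.FileSystemWriter("ckpt"))
dcp.load(state_dict, storage_reader=dcp.FileSystemReader("ckpt"))
```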

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134938
Approved by: https://github.com/wz337, https://github.com/fegin
2024-09-06 03:25:09 +00:00
29d72c1100 [inductor] check intel compiler minimal version (#135209)
On Windows, early versions of icx have a `-print-file-name` issue and cannot preload correctly for inductor. Add a minimal version check for the Intel compiler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135209
Approved by: https://github.com/ezyang
2024-09-06 03:21:07 +00:00
3b1a334c0f [Inductor][CPP] Avoid mistake wgt tensor delete (#135100)
**Summary**
Fix issue https://github.com/pytorch/pytorch/issues/134998: previously, we only checked whether the `get_attr` FX node for the weight had a single user node. However, two `get_attr` nodes may share the same tensor, and the tensor should not be deleted in such cases. In this PR, we count the users of the tensor, in addition to the number of users of the nodes, to decide whether the tensor can be deleted.

**TestPlan**
```
 python test/inductor/test_cpu_select_algorithm.py -k test_linear_wgt_multi_users
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135100
Approved by: https://github.com/jgong5
2024-09-06 03:13:36 +00:00
07689a38bf [Inductor] Fix AOT weight alignment issue on CPU (#135205)
**Summary**
Fix issue: https://github.com/pytorch/pytorch/issues/135027. On CPU, the `consts_size` used to generate `_binary_constants_bin_start` is not padded to `ALIGN_BYTES`, while `serialized_weights` is, causing a failure in the 16K alignment check.
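The gist of such a padding fix, as a hedged sketch (the real constant and code live in the AOTInductor C++ wrapper codegen, not here):
```python
ALIGN_BYTES = 64  # illustrative value only

def pad_to_align(consts_size: int, align: int = ALIGN_BYTES) -> int:
    # Round the serialized-constants size up to the next multiple of `align`
    # so that it matches the padded serialized_weights blob.
    return (consts_size + align - 1) // align * align

assert pad_to_align(100) == 128
assert pad_to_align(128) == 128
```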

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135205
Approved by: https://github.com/jgong5, https://github.com/desertfire
2024-09-06 03:06:51 +00:00
06a7dc21c1 Remove dead expect_rational (#135105)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135105
Approved by: https://github.com/malfet
2024-09-06 02:57:27 +00:00
d9a18173fa Report qualname of exception type rather than <class 'RuntimeError'> (#135146)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
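A hypothetical snippet showing the formatting difference the title refers to (not the PR's code):
```python
class CustomError(RuntimeError):
    pass

try:
    raise CustomError("boom")
except Exception as e:
    print(type(e))                # <class '__main__.CustomError'>
    print(type(e).__qualname__)   # CustomError
```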

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135146
Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/yanboliang
ghstack dependencies: #135148, #135145
2024-09-06 02:56:50 +00:00
d8543e3162 Include exception type qualname when rewrapping InternalTorchDynamoError (#135145)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135145
Approved by: https://github.com/drisspg, https://github.com/anijain2305
ghstack dependencies: #135148
2024-09-06 02:56:50 +00:00
ad01fc194d Consolidate raise and rewrap raise error branches (#135148)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135148
Approved by: https://github.com/anijain2305, https://github.com/albanD, https://github.com/yanboliang, https://github.com/malfet
2024-09-06 02:56:46 +00:00
e162414963 add instrumentation of CCA stats for reserved and allocated memory size (#135231)
As titled
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135231
Approved by: https://github.com/c-p-i-o
2024-09-06 02:48:56 +00:00
9e5a797771 Improve test_public_bindings import module error reporting (#135258)
Error was hard to understand without message. Render it now. See https://github.com/pytorch/pytorch/pull/135259 for it in action.

Example failure:

```
2024-09-05T20:04:45.3022000Z FAILED [5.9524s] test_public_bindings.py::TestPublicBindings::test_modules_can_be_imported - AssertionError: String comparison failed: '' != "torch._logging.scribe failed to import w[112 chars].py)"
2024-09-05T20:04:45.3025413Z + torch._logging.scribe failed to import with error ImportError: cannot import name 'TypeAlias' from 'typing' (/opt/conda/envs/py_3.9/lib/python3.9/typing.py)
2024-09-05T20:04:45.3026990Z
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135258
Approved by: https://github.com/albanD
2024-09-06 02:40:03 +00:00
b46a1b9e2d Use Python 3.9 on all libtorch jobs (#135245)
Part of the migration py3.8->3.9

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135245
Approved by: https://github.com/izaitsevfb
2024-09-06 02:27:22 +00:00
9688014820 aarch64: extend matmul heuristic checks to all neoverse platforms (#134548)
for aarch64 neoverse platforms there are two gemm backends available
for matmul operator on PyTorch: (1) Arm Compute Library and (2) OpenBLAS.
While Arm Compute Library provides better performance over OpenBLAS,
it has overhead for the kernel launch time, and hence we use OpenBLAS
for smaller tensor compute. The heuristic was originally implemented for
neoverse_v1. This commit extends the heuristic to other neoverse platforms

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134548
Approved by: https://github.com/malfet
2024-09-06 01:40:50 +00:00
8f6e73f068 [ONNX] Enable experimental exporter logic to dynamo_export and support refine dynamic_shapes (#134976)
(1) Enable experimental exporter logic to dynamo_export
(2) Refine dynamic shapes and retry export in export strategies
(3) Delete `torch_export_graph_extractor` and use the new export logic
(4) Disable ExportedProgram test in `test_fx_onnx_with_onnxruntime.py`, as ONNXProgram is different now.

Fixes https://github.com/pytorch/pytorch/issues/126479
Fixes #135183
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134976
Approved by: https://github.com/justinchuby
2024-09-06 01:29:56 +00:00
1e57ef08fa [AOTI] Support MKLDNN qconv ops in cpp wrapper (#134795)
Summary: Similar to https://github.com/pytorch/pytorch/pull/134475, support qconv in the ABI-compatible mode for cpp-wrapper Inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134795
Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w, https://github.com/angelayi
ghstack dependencies: #134475, #134783
2024-09-06 01:01:53 +00:00
614b86d602 [AOTI] Support MKLDNN qlinear ops in cpp wrapper (#134783)
Summary: Similar to https://github.com/pytorch/pytorch/pull/134475, support qlinear in the ABI-compatible mode for cpp-wrapper Inductor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134783
Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w, https://github.com/angelayi
ghstack dependencies: #134475
2024-09-06 01:01:53 +00:00
0b96dfb736 [AOTI] Support MKLDNN conv ops in cpp wrapper (#134475)
Summary: Partially fix https://github.com/pytorch/pytorch/issues/123040. In the ABI-compatible mode, MKLDNN fallback ops do not have C shim implementations and thus need to go through the custom ops launch path. Other MLKDNN ops will be fixed in following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134475
Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w, https://github.com/angelayi
2024-09-06 01:01:53 +00:00
62b221d5cc Add Percentages to Function Events (#135155)
Summary: Users have recently asked that the profiler add self/total CPU and device percentages to FunctionEvents so that teams can process the data procedurally. Some of it could be computed mathematically via subroutines, but since we already have the information in _build_table, let's build it there.
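For context, the table in question is the one produced by key_averages().table(); a minimal snippet that generates it (the new percentage columns come from _build_table):
```python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    a = torch.randn(256, 256)
    b = torch.randn(256, 256)
    (a @ b).relu_()

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
```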

Test Plan: Check that we produce the same table as before, and also that the new fields have the expected values.

Differential Revision: D62210351

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135155
Approved by: https://github.com/shanw-meta, https://github.com/kit1980
2024-09-06 00:39:11 +00:00
66dd4577b1 Track base of FunctionalTensor in inference mode. (#135141)
The idea behind the tracking is the following: whenever we see a tensor, if the tensor is a root tensor (it does not have any view metas), we consider it as the base of all the tensors that share its storage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135141
Approved by: https://github.com/zou3519
2024-09-06 00:10:25 +00:00
cyy
cc28634172 [Submodule] Bump pybind11 to v2.13.5 (#135202)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135202
Approved by: https://github.com/Skylion007
2024-09-06 00:09:00 +00:00
c83cdf068b [DTensor] Fix view op replicating on tensor dim when the size of the tensor dim = 1 (#135054)
We found a corner case that when a tensor dimension is 1, calling `view(1)` would result in an unexpected replication (see case 1 below). When the tensor dimension to shard is not 1, no matter whether the tensor dimension is evenly-shardable across the mesh dimension, it won't cause an implicit replication behind the scenes if view doesn't change the size of the given tensor dimension (see case 2 and 3).

When the tensor dimension to shard is of size 1, it is not being added to shardable_dims here:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/_view_ops.py#L518

```
# uneven case where the size of the tensor dimension to shard is 1
p = torch.randn(1,2)
mesh = init_device_mesh("cuda", (2,))
dtensor = distribute_tensor(p, mesh, [Shard(0)])
t = dtensor.view(1, 2)
# this would result in replication, meaning t is now replicated across all ranks.

# uneven case where the size of the tensor dimension to shard is not 1
p = torch.randn(3, 2)
mesh = init_device_mesh("cuda", (2,))
dtensor = distribute_tensor(p, mesh, [Shard(0)])
t = dtensor.view(3, 2) # this would not result in replication.
# this would not result in replication, meaning t stays as sharded.

# even case
p = torch.randn(2,2)
dtensor = distribute_tensor(p, mesh, [Shard(0)])
t = dtensor.view(2, 2)
# this would not result in replication, meaning t stays as sharded.
```

Differential Revision: [D62155606](https://our.internmc.facebook.com/intern/diff/D62155606)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135054
Approved by: https://github.com/tianyu-l, https://github.com/wanchaol
2024-09-06 00:03:54 +00:00
28ccfba248 [ONNX] Delete ONNXProgramSerializer (#135261)
Fixes #135182

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135261
Approved by: https://github.com/justinchuby
2024-09-05 23:52:51 +00:00
b2386bdca1 [debug] Add helper to run cProfile on a function (#135084)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135084
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082
2024-09-05 23:41:30 +00:00
bdfc8d9f96 [fx] Don't use generators in map_aggregate (#135082)
While the generators avoid a copy, they are slow.

Before:
![image](https://github.com/user-attachments/assets/70a55a9a-0595-4105-b0ab-22cf77c7409c)

After:
![image](https://github.com/user-attachments/assets/cecb9c59-ae36-47de-8b08-cab2c7cb3d57)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135082
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076
2024-09-05 23:41:30 +00:00
70779dded8 [fx] Compile time optimization in Node.__update_args_kwargs (#135076)
Before this we took two passes over all of the args.

Before:
![image](https://github.com/user-attachments/assets/24ce5628-03f4-4983-9f2d-5ddf0ca5816e)

After:
![image](https://github.com/user-attachments/assets/c9681aa2-32f0-4f6b-a598-fc6f90ffafb5)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135076
Approved by: https://github.com/Chillee
ghstack dependencies: #135070
2024-09-05 23:41:30 +00:00
ea231300d1 [inductor] Improve compile time regression from MemoryDep.normalize (#135070)
Possible fix for #135056

Before
![image](https://github.com/user-attachments/assets/3962cb85-e808-4fd4-991f-471ff5ef7eae)

After
![image](https://github.com/user-attachments/assets/2322d48d-6518-4518-baca-336027b5cda8)

Measured based on:
```
python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --training --only hf_Bert_large --stats -n1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135070
Approved by: https://github.com/Chillee
2024-09-05 23:41:30 +00:00
8f66995459 Revert "Support rolling over a percentage of workflows (#134816)"
This reverts commit fc890b55b51098437b6149abf1026a8b2aaee389.

Reverted https://github.com/pytorch/pytorch/pull/134816 on behalf of https://github.com/malfet due to Causes lint to intermittently fail ([comment](https://github.com/pytorch/pytorch/pull/134816#issuecomment-2332902609))
2024-09-05 23:39:41 +00:00
144fde4fd2 [MPS] Add support for autocast in MPS (#99272)
Fixes https://github.com/pytorch/pytorch/issues/88415

Need to run inductor/test_cpu_select_algorithm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99272
Approved by: https://github.com/malfet

Co-authored-by: Siddharth Kotapati <skotapati@apple.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Roy Hvaara <roy@lightyear.no>
2024-09-05 23:23:17 +00:00
43f4947d44 fix fake tensor tolist implementation (#135131)
Summary:
When exporting for training with `tolist`, we do not hit `FunctionalTensor.tolist` since we do not functionalize. Unfortunately, this means we hit `FakeTensor.tolist`, which creates unbacked symints that are not backed by proxies.

Rather than trying to patch up this low-level implementation, we replace it with essentially what `FunctionalTensor.tolist` does, which is higher-level: we essentially desugar to `item()` calls and let it take care of unbacked symints.
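Conceptually, the desugaring looks like the following (an illustrative sketch, not the actual implementation):
```python
import torch

def tolist_via_item(t: torch.Tensor):
    # Recursively reduce tolist() to item() calls, which already know how to
    # produce properly tracked unbacked symints under fake tensors.
    if t.dim() == 0:
        return t.item()
    return [tolist_via_item(t[i]) for i in range(t.shape[0])]

print(tolist_via_item(torch.arange(6).reshape(2, 3)))  # [[0, 1, 2], [3, 4, 5]]
```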

Test Plan:
Some expected failures are gone now.
Also found a test for `tolist` that was written when `FunctionalTensor.tolist` was implemented but not really doing much; repurposed it now to exercise more modes.

Differential Revision: D62197742

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135131
Approved by: https://github.com/ezyang
2024-09-05 23:20:31 +00:00
65e1c34061 [rfc] scuba for flight recorder (#134794)
Summary: Record flight recorder status in a scuba table.

Test Plan: Testing with timing out a job. Will post results soon.

Differential Revision: D61729221

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134794
Approved by: https://github.com/fduwjj
2024-09-05 23:18:10 +00:00
830247c355 [Intel Triton] Update Intel Triton to release/2.5.0 (#134074)
This PR relands https://github.com/pytorch/pytorch/pull/134053

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134074
Approved by: https://github.com/EikanWang
2024-09-05 22:46:31 +00:00
4262755b5a [cond] fix typo in cond codegen (#134708)
As titled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134708
Approved by: https://github.com/jansel
2024-09-05 22:38:24 +00:00
3825607144 Add torch._logging.scribe (#135224)
See https://github.com/pytorch/pytorch/pull/135138 for a usage example. Meta only, see https://docs.google.com/document/d/1JpbAQvRhTmuxjnKKjT7qq57dsnV84nxSLpWJo1abJuE/edit#heading=h.9wi46k7np6xw for context

fbscribelogger is a library that allows us to write to scribe, which is Meta's logging infrastructure, when you have an appropriate access token (this token is available for jobs running on main, as well as authorized jobs with the ci-scribe label). The resulting data is accessible via Scuba (a real time in-memory database) and Hive (a more traditional SQL persisted database).

Here's the motivating use case. Suppose there is somewhere in PyTorch's codebase where you'd like to log an event, and then you'd like to find all the situations where this log is called. If PyTorch is rolled out to our internal users, we have some FB-oriented APIs (like torch._utils_internal.signpost_event) with which you can do this. But you have to actually land your PR to main, wait for it to be ingested to fbcode, and then wait for us to actually roll out this version, before you get any data. But what if you want the results within the next few hours? Instead, you can use torch._logging.scribe to directly write to our logging infrastructure *from inside CI jobs.* The most convenient approach is to log unstructured JSON blobs to `open_source_signpost` (added in this PR; you can also add your own dedicated table as described in the GDoc above). After adding logging code to your code, you can push your PR to CI, add the 'ci-scribe' label, and in a few hours view the results in Scuba, e.g., (Meta-only) https://fburl.com/scuba/torch_open_source_signpost/z2mq8o4l If you want continuous logging on all commits on master, you can land your PR and it will continuously get logged for all CI runs that happen on main.

Eventually, if your dataset is important enough, you can consider collaborating with PyTorch Dev Infra to get the data collected in our public AWS cloud so that OSS users can view it without access to Meta's internal systems. But this facility is really good for prototyping / one-off experiments. It's entirely self-serve: just add your logging, run your PR CI with ci-scribe, get results, do analysis in Scuba.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135224
Approved by: https://github.com/Skylion007
2024-09-05 22:37:13 +00:00
eqy
3c8f71ff93 [cuDNN][64-bit indexing] cuDNN v9.3+ supports non-batch-splittable convolutions with > 2**31 elements (#134890)
For longstanding issues such as #95024

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134890
Approved by: https://github.com/Skylion007
2024-09-05 22:22:45 +00:00
fc890b55b5 Support rolling over a percentage of workflows (#134816)
In order to support adding a rollover percentage, this ended up being a complete rewrite of runner_determinator.py.

Details of the new format are in the comments up top.

On the plus side, this now includes some unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134816
Approved by: https://github.com/PaliC, https://github.com/zxiiro
2024-09-05 22:21:45 +00:00
058a69d91a [fbcode][dynamo] Turn on guard_nn_modules using justknobs_check (#134928)
As Title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134928
Approved by: https://github.com/ezyang
2024-09-05 22:05:54 +00:00
6c5920d515 Tune int8 AMX WoQ micro-kernel for CPU (#134832)
This patch prevents performance regression against the default ATen implementation for LLaMA 3.1 int8 GPTQ WoQ workload.

Uses AMX micro-kernel only if `M` >= `block_m`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134832
Approved by: https://github.com/jgong5
2024-09-05 22:01:14 +00:00
116fd474da [export] Expand coverage to more copied sym ops for unflattener. (#135119)
Test Plan:
buck2 test 'fbcode//mode/opt' fbcode//torchrec/ir/tests:test_serializer -- --run-disabled

```
File changed: fbcode//caffe2/torch/export/unflatten.py
Buck UI: https://www.internalfb.com/buck2/2e0377e7-e2b6-4bd0-8133-a787245165a0
Test UI: https://www.internalfb.com/intern/testinfra/testrun/5066549824883887
Network: Up: 0B  Down: 0B
Jobs completed: 16. Time elapsed: 10.2s.
Tests finished: Pass 6. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D62190172

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135119
Approved by: https://github.com/yushangdi
2024-09-05 21:58:20 +00:00
a5d70cf545 [PyTorch] Add isfinite to BFloat16-math.h (#135052)
Missing function from <cmath>.

Differential Revision: [D62148884](https://our.internmc.facebook.com/intern/diff/D62148884/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135052
Approved by: https://github.com/PaliC, https://github.com/albanD
ghstack dependencies: #135031
2024-09-05 21:50:36 +00:00
7fe819d917 [PyTorch] Fix -Wshadow -Werror build in BFloat16-inl.h (#135031)
`float_t` is required to exists in C99 math.h, which causes -Wshadow to fire. We don't need the alias, fortunately.

Differential Revision: [D62135908](https://our.internmc.facebook.com/intern/diff/D62135908/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135031
Approved by: https://github.com/albanD
2024-09-05 21:48:21 +00:00
f63571060c Revert "Use actions/upload-artifact@v4.4.0 for rest of workflows (#135264)"
This reverts commit 9c0b03020b7204ca5d5dbe18174bab005f79c47b.

Reverted https://github.com/pytorch/pytorch/pull/135264 on behalf of https://github.com/atalman due to broke CI ([comment](https://github.com/pytorch/pytorch/pull/135264#issuecomment-2332674607))
2024-09-05 21:43:05 +00:00
38fead8f7c [hop] preserve metadata in re-tracing hop subgraph by running with interpreter (#135159)
In this way, interpreter.run can correctly preserve the current metadata of the subgraphs when tracing them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135159
Approved by: https://github.com/tugsbayasgalan
2024-09-05 21:36:56 +00:00
24a223c49d Run inductor micro benchmark on x86 metal runner (#135042)
This enables inductor micro benchmark on CPU (x86):

* Running on AWS metal runner for more accurate benchmark
* I add a new `arch` column, which will be either x86_64 or arm64 for CPU or GPU name for GPU.  We can use this later to differentiate between different setup, i.e. cuda (a100) vs cuda (a10g) or cpu (x86_64) vs cpu (arm64)

The next step would be to run this one cpu arm64, and cuda (a10g).

### Testing
Here is the CSV results from my test run https://github.com/pytorch/pytorch/actions/runs/10709344180

```
name,metric,target,actual,dtype,device,arch,is_model
mlp_layer_norm_gelu,flops_utilization,0.8,17.36,bfloat16,cpu,x86_64,False
gather_gemv,memory_bandwidth(GB/s),990,170.80,int8,cpu,x86_64,False
gather_gemv,memory_bandwidth(GB/s),1060,204.78,bfloat16,cpu,x86_64,False
Mixtral-8x7B-v0.1,token_per_sec,175,26.68,int8,cpu,x86_64,True
Mixtral-8x7B-v0.1,memory_bandwidth(GB/s),1130,171.91,int8,cpu,x86_64,True
Mixtral-8x7B-v0.1,compilation_time(s),162,47.36,int8,cpu,x86_64,True
gemv,memory_bandwidth(GB/s),870,236.36,int8,cpu,x86_64,False
gemv,memory_bandwidth(GB/s),990,305.71,bfloat16,cpu,x86_64,False
Llama-2-7b-chat-hf,token_per_sec,94,14.01,bfloat16,cpu,x86_64,True
Llama-2-7b-chat-hf,memory_bandwidth(GB/s),1253,185.18,bfloat16,cpu,x86_64,True
Llama-2-7b-chat-hf,compilation_time(s),162,74.99,bfloat16,cpu,x86_64,True
Llama-2-7b-chat-hf,token_per_sec,144,25.09,int8,cpu,x86_64,True
Llama-2-7b-chat-hf,memory_bandwidth(GB/s),957,165.83,int8,cpu,x86_64,True
Llama-2-7b-chat-hf,compilation_time(s),172,70.69,int8,cpu,x86_64,True
layer_norm,memory_bandwidth(GB/s),950,172.03,bfloat16,cpu,x86_64,False
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135042
Approved by: https://github.com/yanboliang
2024-09-05 21:31:36 +00:00
e4920a1364 [Traceable FSDP2][Dynamo] allow tracing through auto_functionalized HOP (#135169)
If an `auto_functionalized` HOP is included in the backward graph due to activation checkpointing, we run into a scenario where Compiled Autograd's Dynamo tracing needs to trace through the `auto_functionalized` HOP. This PR adds support for that.

Test commands:
- `pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_auto_functionalized`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135169
Approved by: https://github.com/zou3519
2024-09-05 21:22:45 +00:00
bc5ecf83d7 [training ir migration] Fix quantization tests (#135184)
Summary:
Fixed some quantization tests for new training ir:

Fix batch norm node pattern matcher. In training ir, we have `aten.batch_norm` node instead of `aten._native_batch_norm_legit` and `aten._native_batch_norm_legit_no_training`.

Test Plan:
```
buck run fbcode//mode/dev-nosan fbcode//caffe2/test:quantization_pt2e
```

Reviewed By: tugsbayasgalan

Differential Revision: D62209819

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135184
Approved by: https://github.com/tugsbayasgalan
2024-09-05 21:19:28 +00:00
e55c0f59e5 Revert "[Reland] Refactor caching device allocator utils (#130923)"
This reverts commit 9809080b9ed657a8c0ea0383be7cbdce3a26e05e.

Reverted https://github.com/pytorch/pytorch/pull/130923 on behalf of https://github.com/kit1980 due to breaking internal builds - Error: Relocation overflow has occured ([comment](https://github.com/pytorch/pytorch/pull/130923#issuecomment-2332640961))
2024-09-05 21:16:14 +00:00
a4cf9653ee Revert "Remove Caffe2 code from tool scripts (#134941)"
This reverts commit c818ecd1698a28d9fadf4a81453a89914b18374a.

Reverted https://github.com/pytorch/pytorch/pull/134941 on behalf of https://github.com/kit1980 due to breaking internal builds - The path `caffe2/operators/hip/gather_op.cuh` does not exist ([comment](https://github.com/pytorch/pytorch/pull/134941#issuecomment-2332636624))
2024-09-05 21:12:54 +00:00
9c0b03020b Use actions/upload-artifact@v4.4.0 for rest of workflows (#135264)
To be consistent with https://github.com/pytorch/pytorch/pull/135263 and rest of workflows. Use v4.4.0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135264
Approved by: https://github.com/kit1980, https://github.com/malfet
2024-09-05 21:05:06 +00:00
034717a029 [ROCm] remove triton-rocm commit pin and merge pins with triton.txt (#133438)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133438
Approved by: https://github.com/jithunnair-amd, https://github.com/malfet

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2024-09-05 20:36:45 +00:00
9c38b00999 [export] Add ability to run eagerly on UnflattenedModule (#133996)
Summary:
Added the contextmanager `_disable_interpreter`, which is meant to be put around a call to `unflatten`. This will generate an UnflattenedModule and sub-InterpreterModules which will not use torch.fx.Interpreter to run eagerly. We want to have this as a state of the module instead of a contextmanager around running the module, because it's not clear where we are calling the unflattened module.

This seems to improve the performance: https://fb.workplace.com/groups/1075192433118967/posts/1473590629945810/?comment_id=1473621763276030

Test Plan: CI

Differential Revision: D60939034

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133996
Approved by: https://github.com/pianpwk
2024-09-05 20:28:42 +00:00
8efe547046 Use actions/upload-artifact@v4.4.0 for triton builds (#135263)
Same as: https://github.com/pytorch/pytorch/pull/135139
Fixes upload failure: https://github.com/pytorch/pytorch/actions/runs/10722567217/job/29748125015
fix regression introduced by https://github.com/pytorch/pytorch/pull/135068

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135263
Approved by: https://github.com/kit1980, https://github.com/huydhn
2024-09-05 20:03:39 +00:00
82d00acfee Allow cross-device copies for cpu scalars in refs (#135140)
This copies our eager-mode behavior where someone can do torch.add(a, b, out=c)
where a and b are CPU scalar tensors and c is a CUDA tensor.
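Concretely, the behavior being preserved (requires a CUDA device to run):
```python
import torch

a = torch.tensor(1.0)               # CPU scalar tensor
b = torch.tensor(2.0)               # CPU scalar tensor
c = torch.empty((), device="cuda")  # CUDA output tensor
torch.add(a, b, out=c)              # eager allows this cross-device copy
print(c)  # tensor(3., device='cuda:0')
```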

Fixes https://github.com/pytorch/pytorch/issues/121619 by side effect (we get into a situation where we're writing a CPU scalar into a FakeTensor that is actually a meta tensor)

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135140
Approved by: https://github.com/williamwen42, https://github.com/yanboliang
2024-09-05 19:08:48 +00:00
098431a29d Update Resize.cpp with new device type (#135117)
Update Resize.cpp with new device type

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135117
Approved by: https://github.com/egienvalue
2024-09-05 18:53:13 +00:00
be660ea2d3 [PT2] Directly set meta.val in group_batch_fusion_aten (#135078)
Summary: instead of using FakeTensorProp after the pass

Differential Revision: D62162640

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135078
Approved by: https://github.com/frank-wei
2024-09-05 18:17:06 +00:00
52c7c89ea4 [Inductor][CPP] Leverage full bits for BF16/FP16 vectorization (#126502)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126502
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-09-05 17:17:46 +00:00
1efd341d15 [fake_tensor] Move unrecognized_type NotImplemented before ConstProp (#135033)
We should not try to do ConstProp on the unrecognized types (e.g. Subclasses).
For those types, throwing NotImplemented will jump to the next torch_dispatch.

Test:
```
 python test/functorch/test_aotdispatch.py -k test_aot_test_subclasses_with_tensor_factories
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135033
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
2024-09-05 17:09:41 +00:00
a096f2899d Add torch.serialization.skip_data context manager (#134504)
## Semantic

The semantic is
(1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler will treat these FakeTensors as being "materialized": space will be reserved in the checkpoint for the associated storage bytes, and when loading, the type will be Tensor instead of FakeTensor.

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])

```

## Follow Ups

- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
2024-09-05 16:53:39 +00:00
dbeb8a1691 Render log filepaths that are not anchored in torch's directory in a reasonable way (#135165)
For example, if I do TORCH_LOGS=fbscribelogger I'll get:

```
I0904 17:59:07.567000 3672513 fbscribelogger/__init__.py:161] stop
```

instead of

```
I0904 12:46:15.332000 2930287 ../../../../../home/ezyang/local/a/pytorch-env/lib/python3.10/site-packages/fbscribelogger/__init__.py:161] stop
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135165
Approved by: https://github.com/Skylion007
2024-09-05 16:48:09 +00:00
b1f72e2984 Gradient scaler for DTensor (#132816)
Solve the request [here](https://github.com/pytorch/pytorch/issues/120003#issuecomment-2248805798).
Enable DTensor input in the gradient scaler's APIs, especially `.unscale_()`.
A related dispatch strategy is added to accept DTensor input.

To allow `found_inf` to be reduced across devices, we add an allreduce over the args at dispatch, after the dispatch strategy and kernel.
Since `aten._amp_foreach_non_finite_check_and_unscale_.default` is an in-place op, `grad_scale` (arg[0]) will be modified in place, so redesigning the strategy or refactoring the kernel would not help.

The test files cover the following, under both 1-d (dp) and 2-d (dp, tp) cases:
1. whether the non-inf values are unscaled
2. whether all DTensors on each device can find inf, even when it is not on their device
3. if inf is not found, whether new parameters are generated
4. if inf is found, whether the scale is updated

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132816
Approved by: https://github.com/XilunWu, https://github.com/weifengpy, https://github.com/wanchaol
2024-09-05 16:44:32 +00:00
bb3c2408f4 [inductor][test] in test_unbacked_symints, replace inductor's skipCUDAIf with common device type's skipcudaif (#133936)
Differential Revision: D61506212

Use `skipCUDAIf` from `torch.testing._internal.common_device_type` if we create the test class with `instantiate_device_type_tests`.

`instantiate_device_type_tests` would make sure the class has attr device_type, which works with`skipCUDAIf` from `torch.testing._internal.common_device_type`.

Also skipping test_vertical_pointwise_reduction_fusion for cpu test class, since the test expects cuda.

FAILED [0.0026s] test/inductor/test_unbacked_symints.py::TestUnbackedSymintsCPU::test_vertical_pointwise_reduction_fusion_cpu - AttributeError: 'TestUnbackedSymintsCPU' object has no attribute 'device'

repro:
```
CUDA_VISIBLE_DEVICES="" pytest test/inductor/test_unbacked_symints.py -k cpu -v
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133936
Approved by: https://github.com/ColinPeppler, https://github.com/desertfire
2024-09-05 16:40:14 +00:00
2c99f17a32 Implement VariableTracker.python_type() (#134215)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134215
Approved by: https://github.com/amjames, https://github.com/jansel
2024-09-05 16:35:47 +00:00
0043dcd79e Switch torch pt2e xnnpack tests to use export_for_training (#134788)
Migrate all the callsites inside the pt2e XNNPACK tests to use export_for_training.

Differential Revision: D61994553

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134788
Approved by: https://github.com/mergennachin
2024-09-05 16:11:18 +00:00
2e2fb668fa Upgrade expecttest to 0.2.1 (#135136)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135136
Approved by: https://github.com/albanD, https://github.com/atalman, https://github.com/Skylion007
2024-09-05 16:05:35 +00:00
9d24f945ba [CI] Use larger instance for building triton whl (#135201)
"Build Triton Wheels" CI jobs were failing due to a lack of resources. This PR uses a larger runner to avoid these issues.

The failure message is like:

```
Process completed with exit code 137.
```

Related running actions:
Failed actions: https://github.com/pytorch/pytorch/actions/runs/10714445036
Success actions: https://github.com/pytorch/pytorch/actions/runs/10716710830

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135201
Approved by: https://github.com/chuanqi129, https://github.com/atalman
2024-09-05 14:36:23 +00:00
ecbd715363 [Intel GPU][Windows] Fix overriding default CMAKE_CXX_FLAGS (#135093)
The root cause is that `/EHsc` is part of the default `CMAKE_CXX_FLAGS` in CMake.
Fix to not override the default `CMAKE_CXX_FLAGS`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135093
Approved by: https://github.com/EikanWang, https://github.com/atalman
2024-09-05 12:52:43 +00:00
58f2477a26 [Dynamo] Support builtin function frozenset (#134563)
Support builtin function frozenset in dynamo
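A minimal example of the kind of code this lets dynamo trace (a hedged sketch; the constants are arbitrary, and the assumption is that the builtin `frozenset` call no longer causes a graph break):
```python
import torch

@torch.compile(fullgraph=True)
def f(x):
    s = frozenset({1, 2, 3})  # builtin frozenset call traced by dynamo
    return x + len(s)

print(f(torch.randn(4)))
```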

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134563
Approved by: https://github.com/anijain2305, https://github.com/EikanWang, https://github.com/jansel
2024-09-05 12:15:10 +00:00
43dcb4bb61 Revise CPU vectorization ISA support API (#135075)
Revising (mostly renaming) CPU vectorization ISA support API (non-frontend-user-facing). Also added AVX512_BF16 ISA detection API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135075
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/ezyang
2024-09-05 12:14:56 +00:00
50d1e37079 [AOTI] Fix a unbacked symint retrieve bug (#134670)
Summary: Fix https://github.com/pytorch/pytorch/issues/134081. When a unbacked symint is computed as the shape of a tensor from a tuple, generated C++ code needs to use std::get<> to extract the tensor.

Differential Revision: [D62142113](https://our.internmc.facebook.com/intern/diff/D62142113)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134670
Approved by: https://github.com/angelayi, https://github.com/22quinn, https://github.com/chenyang78
2024-09-05 11:34:14 +00:00
b99ef1a02e Update torch-xpu-ops pin (ATen XPU implementation) (#135185)
Release cycle for PyTorch 2.5
1. Update specific AOT targets for Windows. On Windows, AOT target list prefers Intel client GPUs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135185
Approved by: https://github.com/EikanWang
2024-09-05 10:05:23 +00:00
8a5c8e5db9 Update unbacked symints in masked_select more precisely (#134899)
## Summary
At the moment, the fake impl for `masked_select` simply sets the upper bound of its size-like SymInt's range to `sys.maxsize` (9223372036854775807, the maximum value of a signed int64) if there are any SymInts in the original input tensor shape. This PR constrains the range more intelligently by using the upper bounds of each SymInt in the input tensor shape.

This solves an issue where a model being lowered to Executorch errors during memory planning, because the memory allocated for `masked_select` ended up exceeding the 64-bit address space (`INT_MAX * size(dtype)`).
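The bound computation amounts to something like the following sketch (illustrative only; the real code works with the symbolic shape environment):
```python
import math

def masked_select_numel_upper_bound(dim_upper_bounds):
    # masked_select returns at most numel(input) elements, so cap the
    # unbacked output size by the product of each dimension's upper bound
    # instead of falling back to sys.maxsize.
    return math.prod(dim_upper_bounds)

assert masked_select_numel_upper_bound([8, 128, 1024]) == 8 * 128 * 1024
```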

## Test plan
- Passes existing unit tests (tests the case where the upper bound is inf)
- Added unit test to verify upper bound reduction calculation
- Tested end-to-end by exporting with TORCH_LOGS="export" and ensuring that the range for `masked_select`'s SymInt size has the correct upper bound
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134899
Approved by: https://github.com/ezyang
2024-09-05 09:01:06 +00:00
c7328dff7f Enhance the stability of the complex divide code (#134647)
In C++, when a floating-point literal (e.g., 3.14) is compared with a variable of type float, the literal is interpreted as a double by default, so the float operand is promoted to double for the comparison.
```c++
float f = 3.14f;
if (f == 3.14) {  // 3.14 is a double literal, so f is promoted to double for this comparison
    // Do something
}
```
If the device does not support double, this promotion causes an error.
This PR addresses complex64 division errors on machines that do not support double operations.
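For reference, a minimal sketch (hypothetical, assuming a backend without float64 support, e.g. some integrated GPUs) of the operation this affects; on CPU it always works, and the fix makes the same complex64 division robust on double-less devices:

```python
import torch

# complex64 division; per this PR's description, the divide code previously
# relied on double-typed literals, which can fail on devices without float64.
a = torch.tensor([1 + 2j, 3 - 1j], dtype=torch.complex64)
b = torch.tensor([2 - 1j, 1 + 1j], dtype=torch.complex64)
print(a / b)
```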

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134647
Approved by: https://github.com/EikanWang, https://github.com/albanD
2024-09-05 08:36:37 +00:00
700 changed files with 21679 additions and 10067 deletions

View File

@ -1,5 +1,5 @@
0.6b
0.7b
manylinux_2_17
rocm6.2
7f07e8a1cb1f99627eb6d77f5c0e9295c775f3c7
e4ab195d2bd19e939c675a13280c29714c6ef9f2cf420690da150fa0cac043b1
9be04068c3c0857a4cfd17d7e39e71d0423ebac2
3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291

View File

@ -108,10 +108,10 @@ ENV CMAKE_C_COMPILER cc
ENV CMAKE_CXX_COMPILER c++
COPY ./common/install_triton.sh install_triton.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt
COPY ci_commit_pins/triton.txt triton.txt
COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt
# Install AOTriton (Early fail)
COPY ./aotriton_version.txt aotriton_version.txt

View File

@ -1 +0,0 @@
21eae954efa5bf584da70324b640288c3ee7aede

View File

@ -1 +1 @@
1b2f15840e0d70eec50d84c7a0575cb835524def
91b14bf5593cf58a8541f3e6b9125600a867d4ef

View File

@ -1 +1 @@
dedb7bdf339a3546896d4820366ca562c586bfa0
5fe38ffd73c2ac6ed6323b554205186696631c6f

View File

@ -4,12 +4,12 @@ set -ex
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
TARBALL='aotriton.tar.bz2'
TARBALL='aotriton.tar.gz'
# This read command alwasy returns with exit code 1
read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true
ARCH=$(uname -m)
AOTRITON_INSTALL_PREFIX="$1"
AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.bz2"
AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.gz"
cd "${AOTRITON_INSTALL_PREFIX}"
# Must use -L to follow redirects

View File

@ -12,10 +12,7 @@ conda_reinstall() {
as_jenkins conda install -q -n py_$ANACONDA_PYTHON_VERSION -y --force-reinstall $*
}
if [ -n "${ROCM_VERSION}" ]; then
TRITON_REPO="https://github.com/openai/triton"
TRITON_TEXT_FILE="triton-rocm"
elif [ -n "${XPU_VERSION}" ]; then
if [ -n "${XPU_VERSION}" ]; then
TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton"
TRITON_TEXT_FILE="triton-xpu"
else

View File

@ -30,9 +30,14 @@ dill==0.3.7
#Pinned versions: 0.3.7
#test that import: dynamo/test_replay_record.py test_dataloader.py test_datapipe.py test_serialization.py
expecttest==0.1.6
expecttest==0.2.1
#Description: method for writing tests where test framework auto populates
# the expected output based on previous runs
#Pinned versions: 0.2.1
#test that import:
fbscribelogger==0.1.6
#Description: write to scribe from authenticated jobs on CI
#Pinned versions: 0.1.6
#test that import:
@ -332,3 +337,8 @@ onnxscript==0.1.0.dev20240817
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
parameterized==0.8.1
#Description: Parameterizes unittests, both the tests themselves and the entire testing class
#Pinned versions:
#test that import:

View File

@ -1 +1 @@
3.0.0
3.1.0

View File

@ -100,10 +100,10 @@ ARG TRITON
# try to reach out to S3, which docker build runners don't have access
COPY ./common/install_triton.sh install_triton.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt
COPY ci_commit_pins/triton.txt triton.txt
COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt
# Install AOTriton
COPY ./aotriton_version.txt aotriton_version.txt

View File

@ -596,6 +596,9 @@ test_single_dynamo_benchmark() {
test_inductor_micro_benchmark() {
TEST_REPORTS_DIR=$(pwd)/test/test-reports
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
test_inductor_set_cpu_affinity
fi
python benchmarks/gpt_fast/benchmark.py --output "${TEST_REPORTS_DIR}/gpt_fast_benchmark.csv"
}

View File

@ -43,6 +43,9 @@ python -m pip install z3-solver==4.12.2.0
# Install tlparse for test\dynamo\test_structured_trace.py UTs.
python -m pip install tlparse==0.3.25
# Install parameterized
python -m pip install parameterized==0.8.1
run_tests() {
# Run nvidia-smi if available
for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do

View File

@ -119,6 +119,11 @@ fi
# Test the package
/builder/check_binary.sh
if [[ "\$GPU_ARCH_TYPE" != *s390x* && "\$GPU_ARCH_TYPE" != *xpu* && "\$GPU_ARCH_TYPE" != *rocm* && "$PACKAGE_TYPE" != libtorch ]]; then
# Exclude s390, xpu, rocm and libtorch builds from smoke testing
python /builder/test/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled
fi
# Clean temp files
cd /builder && git clean -ffdx

View File

@ -90,7 +90,7 @@ fi
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt)
TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
fi
if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then

View File

@ -86,6 +86,18 @@
- pull
- inductor
- name: OSS CI / pytorchbot / slow tests
patterns:
- test/slow_tests.json
approved_by:
- pytorchbot
ignore_flaky_failures: false
mandatory_checks_name:
- EasyCLA
- Lint
- pull
- slow
- name: OSS CI /pytorchbot / Executorch
patterns:
- .ci/docker/ci_commit_pins/executorch.txt

View File

@ -9,6 +9,7 @@ ciflow_push_tags:
- ciflow/inductor-rocm
- ciflow/inductor-perf-compare
- ciflow/inductor-micro-benchmark
- ciflow/inductor-micro-benchmark-cpu-x86
- ciflow/inductor-cu124
- ciflow/linux-aarch64
- ciflow/mps

View File

@ -1,6 +1,7 @@
boto3==1.19.12
hypothesis==6.56.4
expecttest==0.1.6
expecttest==0.2.1
fbscribelogger==0.1.6
librosa>=0.6.2
mpmath==1.3.0
networkx==2.8.7
@ -30,3 +31,4 @@ optree==0.12.1
# NB: test_hparams_* from test_tensorboard is failing with protobuf 5.26.0 in
# which the stringify metadata is wrong when escaping double quote
protobuf==3.20.2
parameterized==0.8.1

View File

@ -15,9 +15,7 @@ REPO_DIR = SCRIPT_DIR.parent.parent
def read_triton_pin(device: str = "cuda") -> str:
triton_file = "triton.txt"
if device == "rocm":
triton_file = "triton-rocm.txt"
elif device == "xpu":
if device == "xpu":
triton_file = "triton-xpu.txt"
with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / triton_file) as f:
return f.read().strip()

View File

@ -325,6 +325,7 @@ def generate_wheels_matrix(
os: str,
arches: Optional[List[str]] = None,
python_versions: Optional[List[str]] = None,
use_split_build: bool = False,
) -> List[Dict[str, str]]:
package_type = "wheel"
if os == "linux" or os == "linux-aarch64" or os == "linux-s390x":
@ -371,7 +372,17 @@ def generate_wheels_matrix(
) and python_version == "3.13":
continue
if use_split_build and (
arch_version not in ["12.4", "12.1", "11.8", "cpu"] or os != "linux"
):
raise RuntimeError(
"Split build is only supported on linux with cuda 12.4, 12.1, 11.8, and cpu.\n"
f"Currently attempting to build on arch version {arch_version} and os {os}.\n"
"Please modify the matrix generation to exclude this combination."
)
# 12.1 linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install
if (
arch_version in ["12.4", "12.1", "11.8"]
and os == "linux"
@ -385,6 +396,7 @@ def generate_wheels_matrix(
"desired_cuda": translate_desired_cuda(
gpu_arch_type, gpu_arch_version
),
"use_split_build": "True" if use_split_build else "False",
"devtoolset": (
"cxx11-abi" if arch_version == "cuda-aarch64" else ""
),
@ -400,7 +412,8 @@ def generate_wheels_matrix(
),
}
)
if arch_version != "cuda-aarch64":
# Special build building to use on Colab. PyThon 3.10 for 12.1 CUDA
if python_version == "3.10" and arch_version == "12.1":
ret.append(
{
"python_version": python_version,
@ -409,40 +422,16 @@ def generate_wheels_matrix(
"desired_cuda": translate_desired_cuda(
gpu_arch_type, gpu_arch_version
),
"use_split_build": "True",
"use_split_build": "True" if use_split_build else "False",
"devtoolset": "",
"container_image": WHEEL_CONTAINER_IMAGES[arch_version],
"package_type": package_type,
"pytorch_extra_install_requirements": (
PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version] # fmt: skip
if os != "linux-aarch64"
else ""
),
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-split".replace( # noqa: B950
"pytorch_extra_install_requirements": "",
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-full".replace( # noqa: B950
".", "_"
),
}
)
# Special build building to use on Colab. PyThon 3.10 for 12.1 CUDA
if python_version == "3.10" and arch_version == "12.1":
ret.append(
{
"python_version": python_version,
"gpu_arch_type": gpu_arch_type,
"gpu_arch_version": gpu_arch_version,
"desired_cuda": translate_desired_cuda(
gpu_arch_type, gpu_arch_version
),
"use_split_build": "False",
"devtoolset": "",
"container_image": WHEEL_CONTAINER_IMAGES[arch_version],
"package_type": package_type,
"pytorch_extra_install_requirements": "",
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-full".replace( # noqa: B950
".", "_"
),
}
)
else:
ret.append(
{
@ -452,6 +441,7 @@ def generate_wheels_matrix(
"desired_cuda": translate_desired_cuda(
gpu_arch_type, gpu_arch_version
),
"use_split_build": "True" if use_split_build else "False",
"devtoolset": (
"cxx11-abi" if arch_version == "cpu-cxx11-abi" else ""
),
@ -467,6 +457,7 @@ def generate_wheels_matrix(
),
}
)
return ret

View File

@ -61,6 +61,7 @@ class BinaryBuildWorkflow:
# Mainly for macos
cross_compile_arm64: bool = False
macos_runner: str = "macos-14-xlarge"
use_split_build: bool = False
def __post_init__(self) -> None:
if self.abi_version:
@ -69,12 +70,20 @@ class BinaryBuildWorkflow:
)
else:
self.build_environment = f"{self.os}-binary-{self.package_type}"
if self.use_split_build:
# added to distinguish concurrency groups
self.build_environment += "-split"
def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
output_file_path = (
GITHUB_DIR
/ f"workflows/generated-{self.build_environment}-{self.branches}.yml"
)
if self.use_split_build:
output_file_path = (
GITHUB_DIR
/ f"workflows/generated-{self.build_environment}-{self.branches}"
)
with open(output_file_path, "w") as output_file:
GENERATED = "generated" # Note that please keep the variable GENERATED otherwise phabricator will hide the whole file
output_file.writelines([f"# @{GENERATED} DO NOT EDIT MANUALLY\n"])
@ -110,6 +119,20 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
isolated_workflow=True,
),
),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="manywheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
use_split_build=True,
arches=["11.8", "12.1", "12.4", "cpu"],
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
isolated_workflow=True,
),
use_split_build=True,
),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="conda",
@ -162,6 +185,21 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
),
branches="main",
),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="manywheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
arches=["11.8", "12.1", "12.4"],
python_versions=["3.9"],
use_split_build=True,
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_PERIODIC},
),
branches="main",
use_split_build=True,
),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="libtorch",

View File

@ -3,49 +3,94 @@
"""
This runner determinator is used to determine which set of runners to run a
GitHub job on. It uses the first comment of a GitHub issue (by default
https://github.com/pytorch/test-infra/issues/5132) as a user list to determine
which users will get their jobs to run on experimental runners. This user list
is also a comma separated list of additional features or experiments which the
user could be opted in to.
https://github.com/pytorch/test-infra/issues/5132) to define the configuration
of which runners should be used to run which job.
The configuration has two parts, the settings and a list of opted-in users,
separated by a line containing "---". If the line is not present, the
settings are considered to be empty with only the second part, the user
list, defined.
The first part is a YAML block that defines the rollout settings. This can be
used to define any settings that are needed to determine which runners to use.
It's fields are defined by the RolloutSettings class below.
The second part is a list of users who are explicitly opted in to the LF fleet.
The user list is also a comma separated list of additional features or
experiments which the user could be opted in to.
The user list has the following rules:
- Users are GitHub usernames with the @ prefix
- If the first line is a "*" then all users will use the new runners
- If the first line is a "!" then all users will use the old runners
- Users are GitHub usernames, which must start with the @ prefix
- Each user is also a comma-separated list of features/experiments to enable
- A "#" prefix indicates the user is opted out of the new runners but is opting
into features/experiments.
- A "#" prefix opts the user out of all experiments
Example user list:
Example config:
# A list of experiments that can be opted into.
# This defines the behavior they'll induce when opted into.
# Expected syntax is:
# [experiment_name]: # Name of the experiment. Also used for the label prefix.
# rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in.
@User1
@User2,amz2023
#@UserOptOutOfNewRunner,amz2023
experiments:
lf:
rollout_percent: 25
---
# Opt-ins:
# Users can opt into the LF fleet by adding their GitHub username to this list
# and specifying experiments to enable in a comma-separated list.
# Experiments should be from the above list.
@User1,lf,split_build
@User2,lf
@User3,split_build
"""
import logging
import os
import random
from argparse import ArgumentParser
from logging import LogRecord
from typing import Any, Iterable
from typing import Any, Dict, Iterable, List, NamedTuple, Tuple
import yaml
from github import Auth, Github
from github.Issue import Issue
WORKFLOW_LABEL_META = "" # use meta runners
DEFAULT_LABEL_PREFIX = "" # use meta runners
WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation
WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation
RUNNER_AMI_LEGACY = ""
RUNNER_AMI_AMZ2023 = "amz2023"
GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "")
GH_OUTPUT_KEY_AMI = "runner-ami"
GH_OUTPUT_KEY_LABEL_TYPE = "label-type"
SETTING_EXPERIMENTS = "experiments"
LF_FLEET_EXPERIMENT = "lf"
CANARY_FLEET_SUFFIX = ".c"
class Experiment(NamedTuple):
rollout_perc: float = (
0 # Percentage of workflows to experiment on when user is not opted-in.
)
# Add more fields as needed
class Settings(NamedTuple):
"""
Settings for the experiments that can be opted into.
"""
experiments: Dict[str, Experiment] = {}
class ColorFormatter(logging.Formatter):
"""Color codes the log messages based on the log level"""
@ -172,85 +217,180 @@ def is_exception_branch(branch: str) -> bool:
return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}
def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str:
"""
Determines if the job should run on the LF fleet or the Meta fleet
Returns:
The appropriate label prefix for the runner, corresponding to the fleet to use.
This gets prefixed to the very start of the runner label.
"""
def load_yaml(yaml_text: str) -> Any:
try:
if rollout_state[0] == "!":
log.info("LF Workflows are disabled for everyone. Using meta runners.")
return WORKFLOW_LABEL_META
elif rollout_state[0] == "*":
log.info("LF Workflows are enabled for everyone. Using LF runners.")
return WORKFLOW_LABEL_LF
else:
all_opted_in_users = {
usr_raw.strip("\n\t@ ").split(",")[0]
for usr_raw in rollout_state.split()
}
opted_in_requestors = {
usr for usr in workflow_requestors if usr in all_opted_in_users
}
if opted_in_requestors:
log.info(
f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners."
)
return WORKFLOW_LABEL_LF
else:
log.info(
f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners."
)
return WORKFLOW_LABEL_META
except Exception as e:
log.error(
f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
)
return WORKFLOW_LABEL_META
data = yaml.safe_load(yaml_text)
return data
except yaml.YAMLError as exc:
log.exception("Error loading YAML")
raise
def get_optin_feature(
rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str
def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
"""
Extracts the text with settings, if any, and the opted in users from the rollout state.
If the issue body contains "---" then the text above that is the settings
and the text below is the list of opted in users.
If it doesn't contain "---" then the settings are empty and the rest is the users.
"""
rollout_state_parts = rollout_state.split("---")
if len(rollout_state_parts) >= 2:
return rollout_state_parts[0], rollout_state_parts[1]
else:
return "", rollout_state
class UserOptins(Dict[str, List[str]]):
"""
Dictionary of users with a list of features they have opted into
"""
def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins:
"""
Parse the user opt-in text into a key value pair of username and the list of features they have opted into
Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable.
- Example line: "@User1,lf,split_build"
- A "#" prefix indicates the user is opted out of all experiments
"""
optins = UserOptins()
for user in user_optin_text.split("\n"):
user = user.strip("\r\n\t -")
if not user or not user.startswith("@"):
# Not a valid user. Skip
continue
if user:
usr_name = user.split(",")[0].strip("@")
optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]]
return optins
def parse_settings_from_text(settings_text: str) -> Settings:
"""
Parse the experiments from the issue body into a list of ExperimentSettings
"""
try:
if settings_text:
# Escape the backtick as well so that we can have the settings in a code block on the GH issue
# for easy reading
# Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on
# the backtick character in shell commands.
backtick = chr(96) # backtick character
settings_text = settings_text.strip(f"\r\n\t{backtick} ")
settings = load_yaml(settings_text)
# For now we just load experiments. We can expand this if/when we add more settings
experiments = {}
for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items():
valid_settings = {}
for setting in exp_settings:
if setting not in Experiment._fields:
log.warning(
f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}"
)
else:
valid_settings[setting] = exp_settings[setting]
experiments[exp_name] = Experiment(**valid_settings)
return Settings(experiments)
except Exception:
log.exception("Failed to parse settings")
return Settings()
def parse_settings(rollout_state: str) -> Settings:
"""
Parse settings, if any, from the rollout state.
If the issue body contains "---" then the text above that is the settings
and the text below is the list of opted in users.
If it doesn't contain "---" then the settings are empty and the default values are used.
"""
settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state)
return parse_settings_from_text(settings_text)
def parse_users(rollout_state: str) -> UserOptins:
"""
Parse users from the rollout state.
"""
_, users_text = extract_settings_user_opt_in_from_text(rollout_state)
return parse_user_opt_in_from_text(users_text)
def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool:
"""
Check if a user is opted into an experiment
"""
return experiment_name in user_optins.get(user, [])
def get_runner_prefix(
rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False
) -> str:
"""
Used to dynamically opt in jobs to specific runner-type variants.
settings = parse_settings(rollout_state)
user_optins = parse_users(rollout_state)
Returns:
The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string.
This variant name is prefixed to the runner-type in the label.
"""
try:
userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()}
all_opted_in_users = set()
for user in userlist:
for i in user.split(","):
if i == feature:
all_opted_in_users.add(user.split(",")[0])
opted_in_requestors = {
usr for usr in workflow_requestors if usr in all_opted_in_users
}
fleet_prefix = ""
prefixes = []
for experiment_name, experiment_settings in settings.experiments.items():
enabled = False
if opted_in_requestors:
# Is any workflow_requestor opted in to this experiment?
opted_in_users = [
requestor
for requestor in workflow_requestors
if is_user_opted_in(requestor, user_optins, experiment_name)
]
if opted_in_users:
log.info(
f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}."
f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}."
)
return feature
else:
log.info(
f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"."
)
return fallback
enabled = True
elif experiment_settings.rollout_perc:
# If no user is opted in, then we randomly enable the experiment based on the rollout percentage
if random.uniform(0, 100) <= experiment_settings.rollout_perc:
log.info(
f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}."
)
enabled = True
except Exception as e:
if enabled:
label = experiment_name
if experiment_name == LF_FLEET_EXPERIMENT:
# We give some special treatment to the "lf" experiment since determines the fleet we use
# - If it's enabled, then we always list it's prefix first
# - If we're in the canary branch, then we append ".c" to the lf prefix
if is_canary:
label += CANARY_FLEET_SUFFIX
fleet_prefix = label
else:
prefixes.append(label)
if len(prefixes) > 1:
log.error(
f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}'
f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}"
)
return fallback
prefixes = prefixes[:1]
# Fleet always comes first
if fleet_prefix:
prefixes.insert(0, fleet_prefix)
return ".".join(prefixes) + "." if prefixes else ""
def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str:
@ -268,9 +408,10 @@ def main() -> None:
args = parse_args()
if args.github_ref_type == "branch" and is_exception_branch(args.github_branch):
log.info(f"Exception branch: '{args.github_branch}', using meta runners")
label_type = WORKFLOW_LABEL_META
runner_ami = RUNNER_AMI_LEGACY
log.info(
f"Exception branch: '{args.github_branch}', using Meta runners and no experiments."
)
runner_label_prefix = DEFAULT_LABEL_PREFIX
else:
try:
rollout_state = get_rollout_state_from_issue(
@ -285,35 +426,18 @@ def main() -> None:
args.github_branch,
)
label_type = get_fleet(
rollout_state,
(
args.github_issue_owner,
username,
),
)
runner_ami = get_optin_feature(
rollout_state=rollout_state,
workflow_requestors=(
args.github_issue_owner,
username,
),
feature=RUNNER_AMI_AMZ2023,
fallback=RUNNER_AMI_LEGACY,
is_canary = args.github_repo == "pytorch/pytorch-canary"
runner_label_prefix = get_runner_prefix(
rollout_state, (args.github_issue_owner, username), is_canary
)
except Exception as e:
log.error(
f"Failed to get issue. Falling back to meta runners. Exception: {e}"
f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}"
)
label_type = WORKFLOW_LABEL_META
runner_ami = RUNNER_AMI_LEGACY
# For Canary builds use canary runners
if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF:
label_type = WORKFLOW_LABEL_LF_CANARY
set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type)
set_github_output(GH_OUTPUT_KEY_AMI, runner_ami)
set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix)
if __name__ == "__main__":

View File

@ -51,6 +51,8 @@ def main() -> None:
for platform_image in platform_images: # type: ignore[attr-defined]
for arch in platform_image.keys(): # type: ignore[attr-defined]
if arch == "cpu-s390x":
continue
tag_image(
platform_image[arch], # type: ignore[index]
default_tag,

View File

@ -0,0 +1,237 @@
from unittest import main, TestCase
from unittest.mock import Mock, patch
import runner_determinator as rd
class TestRunnerDeterminatorIssueParser(TestCase):
def test_parse_settings(self) -> None:
settings_text = """
experiments:
lf:
rollout_perc: 25
otherExp:
rollout_perc: 0
---
Users:
@User1,lf
@User2,lf,otherExp
"""
settings = rd.parse_settings(settings_text)
self.assertTupleEqual(
rd.Experiment(rollout_perc=25),
settings.experiments["lf"],
"lf settings not parsed correctly",
)
self.assertTupleEqual(
rd.Experiment(rollout_perc=0),
settings.experiments["otherExp"],
"otherExp settings not parsed correctly",
)
def test_parse_settings_in_code_block(self) -> None:
settings_text = """
```
experiments:
lf:
rollout_perc: 25
otherExp:
rollout_perc: 0
```
---
Users:
@User1,lf
@User2,lf,otherExp
"""
settings = rd.parse_settings(settings_text)
self.assertTupleEqual(
rd.Experiment(rollout_perc=25),
settings.experiments["lf"],
"lf settings not parsed correctly",
)
self.assertTupleEqual(
rd.Experiment(rollout_perc=0),
settings.experiments["otherExp"],
"otherExp settings not parsed correctly",
)
def test_parse_users(self) -> None:
settings_text = """
experiments:
lf:
rollout_perc: 0
otherExp:
rollout_perc: 0
---
Users:
@User1,lf
@User2,lf,otherExp
"""
users = rd.parse_users(settings_text)
self.assertDictEqual(
{"User1": ["lf"], "User2": ["lf", "otherExp"]},
users,
"Users not parsed correctly",
)
def test_parse_users_without_settings(self) -> None:
settings_text = """
@User1,lf
@User2,lf,otherExp
"""
users = rd.parse_users(settings_text)
self.assertDictEqual(
{"User1": ["lf"], "User2": ["lf", "otherExp"]},
users,
"Users not parsed correctly",
)
class TestRunnerDeterminatorGetRunnerPrefix(TestCase):
def test_opted_in_user(self) -> None:
settings_text = """
experiments:
lf:
rollout_perc: 0
otherExp:
rollout_perc: 0
---
Users:
@User1,lf
@User2,lf,otherExp
"""
prefix = rd.get_runner_prefix(settings_text, ["User1"])
self.assertEqual("lf.", prefix, "Runner prefix not correct for User1")
def test_opted_in_user_two_experiments(self) -> None:
settings_text = """
experiments:
lf:
rollout_perc: 0
otherExp:
rollout_perc: 0
---
Users:
@User1,lf
@User2,lf,otherExp
"""
prefix = rd.get_runner_prefix(settings_text, ["User2"])
self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2")
@patch("random.uniform", return_value=50)
def test_opted_out_user(self, mock_uniform: Mock) -> None:
settings_text = """
experiments:
lf:
rollout_perc: 25
otherExp:
rollout_perc: 25
---
Users:
@User1,lf
@User2,lf,otherExp
"""
prefix = rd.get_runner_prefix(settings_text, ["User3"])
self.assertEqual("", prefix, "Runner prefix not correct for user")
@patch("random.uniform", return_value=10)
def test_opted_out_user_was_pulled_in_by_rollout(self, mock_uniform: Mock) -> None:
settings_text = """
experiments:
lf:
rollout_perc: 25
otherExp:
rollout_perc: 25
---
Users:
@User1,lf
@User2,lf,otherExp
"""
# User3 is opted out, but is pulled into both experiments by the 10% rollout
prefix = rd.get_runner_prefix(settings_text, ["User3"])
self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user")
def test_lf_prefix_always_comes_first(self) -> None:
settings_text = """
experiments:
otherExp:
rollout_perc: 0
lf:
rollout_perc: 0
---
Users:
@User1,lf
@User2,otherExp,lf
"""
prefix = rd.get_runner_prefix(settings_text, ["User2"])
self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user")
def test_ignores_commented_users(self) -> None:
settings_text = """
experiments:
lf:
rollout_perc: 0
otherExp:
rollout_perc: 0
---
Users:
#@User1,lf
@User2,lf,otherExp
"""
prefix = rd.get_runner_prefix(settings_text, ["User1"])
self.assertEqual("", prefix, "Runner prefix not correct for user")
def test_ignores_extra_experiments(self) -> None:
settings_text = """
experiments:
lf:
rollout_perc: 0
otherExp:
rollout_perc: 0
foo:
rollout_perc: 0
---
Users:
@User1,lf,otherExp,foo
"""
prefix = rd.get_runner_prefix(settings_text, ["User1"])
self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user")
if __name__ == "__main__":
main()

View File

@ -45,7 +45,7 @@
{%- if is_windows %}
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
{%- endif %}
{%- else %}

View File

@ -62,49 +62,94 @@ jobs:
"""
This runner determinator is used to determine which set of runners to run a
GitHub job on. It uses the first comment of a GitHub issue (by default
https://github.com/pytorch/test-infra/issues/5132) as a user list to determine
which users will get their jobs to run on experimental runners. This user list
is also a comma separated list of additional features or experiments which the
user could be opted in to.
https://github.com/pytorch/test-infra/issues/5132) to define the configuration
of which runners should be used to run which job.
The configuration has two parts, the settings and a list of opted-in users,
separated by a line containing "---". If the line is not present, the
settings are considered to be empty with only the second part, the user
list, defined.
The first part is a YAML block that defines the rollout settings. This can be
used to define any settings that are needed to determine which runners to use.
It's fields are defined by the RolloutSettings class below.
The second part is a list of users who are explicitly opted in to the LF fleet.
The user list is also a comma separated list of additional features or
experiments which the user could be opted in to.
The user list has the following rules:
- Users are GitHub usernames with the @ prefix
- If the first line is a "*" then all users will use the new runners
- If the first line is a "!" then all users will use the old runners
- Users are GitHub usernames, which must start with the @ prefix
- Each user is also a comma-separated list of features/experiments to enable
- A "#" prefix indicates the user is opted out of the new runners but is opting
into features/experiments.
- A "#" prefix opts the user out of all experiments
Example user list:
Example config:
# A list of experiments that can be opted into.
# This defines the behavior they'll induce when opted into.
# Expected syntax is:
# [experiment_name]: # Name of the experiment. Also used for the label prefix.
# rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in.
@User1
@User2,amz2023
#@UserOptOutOfNewRunner,amz2023
experiments:
lf:
rollout_percent: 25
---
# Opt-ins:
# Users can opt into the LF fleet by adding their GitHub username to this list
# and specifying experiments to enable in a comma-separated list.
# Experiments should be from the above list.
@User1,lf,split_build
@User2,lf
@User3,split_build
"""
import logging
import os
import random
from argparse import ArgumentParser
from logging import LogRecord
from typing import Any, Iterable
from typing import Any, Dict, Iterable, List, NamedTuple, Tuple
import yaml
from github import Auth, Github
from github.Issue import Issue
WORKFLOW_LABEL_META = "" # use meta runners
DEFAULT_LABEL_PREFIX = "" # use meta runners
WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation
WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation
RUNNER_AMI_LEGACY = ""
RUNNER_AMI_AMZ2023 = "amz2023"
GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "")
GH_OUTPUT_KEY_AMI = "runner-ami"
GH_OUTPUT_KEY_LABEL_TYPE = "label-type"
SETTING_EXPERIMENTS = "experiments"
LF_FLEET_EXPERIMENT = "lf"
CANARY_FLEET_SUFFIX = ".c"
class Experiment(NamedTuple):
rollout_perc: float = (
0 # Percentage of workflows to experiment on when user is not opted-in.
)
# Add more fields as needed
class Settings(NamedTuple):
"""
Settings for the experiments that can be opted into.
"""
experiments: Dict[str, Experiment] = {}
class ColorFormatter(logging.Formatter):
"""Color codes the log messages based on the log level"""
@ -231,85 +276,180 @@ jobs:
return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}
def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str:
"""
Determines if the job should run on the LF fleet or the Meta fleet
Returns:
The appropriate label prefix for the runner, corresponding to the fleet to use.
This gets prefixed to the very start of the runner label.
"""
def load_yaml(yaml_text: str) -> Any:
try:
if rollout_state[0] == "!":
log.info("LF Workflows are disabled for everyone. Using meta runners.")
return WORKFLOW_LABEL_META
elif rollout_state[0] == "*":
log.info("LF Workflows are enabled for everyone. Using LF runners.")
return WORKFLOW_LABEL_LF
else:
all_opted_in_users = {
usr_raw.strip("\n\t@ ").split(",")[0]
for usr_raw in rollout_state.split()
}
opted_in_requestors = {
usr for usr in workflow_requestors if usr in all_opted_in_users
}
if opted_in_requestors:
log.info(
f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners."
)
return WORKFLOW_LABEL_LF
else:
log.info(
f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners."
)
return WORKFLOW_LABEL_META
except Exception as e:
log.error(
f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
)
return WORKFLOW_LABEL_META
data = yaml.safe_load(yaml_text)
return data
except yaml.YAMLError as exc:
log.exception("Error loading YAML")
raise
def get_optin_feature(
rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str
def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
"""
Extracts the text with settings, if any, and the opted in users from the rollout state.
If the issue body contains "---" then the text above that is the settings
and the text below is the list of opted in users.
If it doesn't contain "---" then the settings are empty and the rest is the users.
"""
rollout_state_parts = rollout_state.split("---")
if len(rollout_state_parts) >= 2:
return rollout_state_parts[0], rollout_state_parts[1]
else:
return "", rollout_state
class UserOptins(Dict[str, List[str]]):
"""
Dictionary of users with a list of features they have opted into
"""
def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins:
"""
Parse the user opt-in text into a key value pair of username and the list of features they have opted into
Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable.
- Example line: "@User1,lf,split_build"
- A "#" prefix indicates the user is opted out of all experiments
"""
optins = UserOptins()
for user in user_optin_text.split("\n"):
user = user.strip("\r\n\t -")
if not user or not user.startswith("@"):
# Not a valid user. Skip
continue
if user:
usr_name = user.split(",")[0].strip("@")
optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]]
return optins
def parse_settings_from_text(settings_text: str) -> Settings:
"""
Parse the experiments from the issue body into a list of ExperimentSettings
"""
try:
if settings_text:
# Escape the backtick as well so that we can have the settings in a code block on the GH issue
# for easy reading
# Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on
# the backtick character in shell commands.
backtick = chr(96) # backtick character
settings_text = settings_text.strip(f"\r\n\t{backtick} ")
settings = load_yaml(settings_text)
# For now we just load experiments. We can expand this if/when we add more settings
experiments = {}
for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items():
valid_settings = {}
for setting in exp_settings:
if setting not in Experiment._fields:
log.warning(
f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}"
)
else:
valid_settings[setting] = exp_settings[setting]
experiments[exp_name] = Experiment(**valid_settings)
return Settings(experiments)
except Exception:
log.exception("Failed to parse settings")
return Settings()
def parse_settings(rollout_state: str) -> Settings:
"""
Parse settings, if any, from the rollout state.
If the issue body contains "---" then the text above that is the settings
and the text below is the list of opted in users.
If it doesn't contain "---" then the settings are empty and the default values are used.
"""
settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state)
return parse_settings_from_text(settings_text)
def parse_users(rollout_state: str) -> UserOptins:
"""
Parse users from the rollout state.
"""
_, users_text = extract_settings_user_opt_in_from_text(rollout_state)
return parse_user_opt_in_from_text(users_text)
def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool:
"""
Check if a user is opted into an experiment
"""
return experiment_name in user_optins.get(user, [])
def get_runner_prefix(
rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False
) -> str:
"""
Used to dynamically opt in jobs to specific runner-type variants.
settings = parse_settings(rollout_state)
user_optins = parse_users(rollout_state)
Returns:
The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string.
This variant name is prefixed to the runner-type in the label.
"""
try:
userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()}
all_opted_in_users = set()
for user in userlist:
for i in user.split(","):
if i == feature:
all_opted_in_users.add(user.split(",")[0])
opted_in_requestors = {
usr for usr in workflow_requestors if usr in all_opted_in_users
}
fleet_prefix = ""
prefixes = []
for experiment_name, experiment_settings in settings.experiments.items():
enabled = False
if opted_in_requestors:
# Is any workflow_requestor opted in to this experiment?
opted_in_users = [
requestor
for requestor in workflow_requestors
if is_user_opted_in(requestor, user_optins, experiment_name)
]
if opted_in_users:
log.info(
f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}."
f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}."
)
return feature
else:
log.info(
f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"."
)
return fallback
enabled = True
elif experiment_settings.rollout_perc:
# If no user is opted in, then we randomly enable the experiment based on the rollout percentage
if random.uniform(0, 100) <= experiment_settings.rollout_perc:
log.info(
f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}."
)
enabled = True
except Exception as e:
if enabled:
label = experiment_name
if experiment_name == LF_FLEET_EXPERIMENT:
# We give some special treatment to the "lf" experiment since determines the fleet we use
# - If it's enabled, then we always list it's prefix first
# - If we're in the canary branch, then we append ".c" to the lf prefix
if is_canary:
label += CANARY_FLEET_SUFFIX
fleet_prefix = label
else:
prefixes.append(label)
if len(prefixes) > 1:
log.error(
f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}'
f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}"
)
return fallback
prefixes = prefixes[:1]
# Fleet always comes first
if fleet_prefix:
prefixes.insert(0, fleet_prefix)
return ".".join(prefixes) + "." if prefixes else ""
def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str:
@ -327,9 +467,10 @@ jobs:
args = parse_args()
if args.github_ref_type == "branch" and is_exception_branch(args.github_branch):
log.info(f"Exception branch: '{args.github_branch}', using meta runners")
label_type = WORKFLOW_LABEL_META
runner_ami = RUNNER_AMI_LEGACY
log.info(
f"Exception branch: '{args.github_branch}', using Meta runners and no experiments."
)
runner_label_prefix = DEFAULT_LABEL_PREFIX
else:
try:
rollout_state = get_rollout_state_from_issue(
@ -344,35 +485,18 @@ jobs:
args.github_branch,
)
label_type = get_fleet(
rollout_state,
(
args.github_issue_owner,
username,
),
)
runner_ami = get_optin_feature(
rollout_state=rollout_state,
workflow_requestors=(
args.github_issue_owner,
username,
),
feature=RUNNER_AMI_AMZ2023,
fallback=RUNNER_AMI_LEGACY,
is_canary = args.github_repo == "pytorch/pytorch-canary"
runner_label_prefix = get_runner_prefix(
rollout_state, (args.github_issue_owner, username), is_canary
)
except Exception as e:
log.error(
f"Failed to get issue. Falling back to meta runners. Exception: {e}"
f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}"
)
label_type = WORKFLOW_LABEL_META
runner_ami = RUNNER_AMI_LEGACY
# For Canary builds use canary runners
if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF:
label_type = WORKFLOW_LABEL_LF_CANARY
set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type)
set_github_output(GH_OUTPUT_KEY_AMI, runner_ami)
set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix)
if __name__ == "__main__":

View File

@ -29,9 +29,19 @@ concurrency:
cancel-in-progress: true
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
build-docker-cuda:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
strategy:
matrix:
cuda_version: ["12.4", "12.1", "11.8"]
@ -66,7 +76,8 @@ jobs:
.ci/docker/libtorch/build.sh libtorch-cxx11-builder:cuda${{matrix.cuda_version}}
build-docker-rocm:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
strategy:
matrix:
rocm_version: ["6.1", "6.2"]
@ -101,7 +112,8 @@ jobs:
.ci/docker/libtorch/build.sh libtorch-cxx11-builder:rocm${{matrix.rocm_version}}
build-docker-cpu:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main

View File

@ -33,9 +33,19 @@ concurrency:
cancel-in-progress: true
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
build-docker-cuda:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: am2.linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral"
strategy:
matrix:
cuda_version: ["12.4", "12.1", "11.8"]
@ -73,7 +83,8 @@ jobs:
# NOTE: manylinux_2_28 are still experimental, see https://github.com/pytorch/pytorch/issues/123649
build-docker-cuda-manylinux_2_28:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
strategy:
matrix:
cuda_version: ["12.4", "12.1", "11.8"]
@ -110,7 +121,8 @@ jobs:
.ci/docker/manywheel/build.sh manylinux2_28-builder:cuda${{matrix.cuda_version}}
build-docker-cuda-aarch64:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.arm64.2xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral"
strategy:
matrix:
cuda_version: ["12.4"]
@ -143,7 +155,8 @@ jobs:
.ci/docker/manywheel/build.sh manylinuxaarch64-builder:cuda${{matrix.cuda_version}}
build-docker-rocm:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: am2.linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral"
strategy:
matrix:
rocm_version: ["6.1", "6.2"]
@ -178,7 +191,8 @@ jobs:
.ci/docker/manywheel/build.sh manylinux-builder:rocm${{matrix.rocm_version}}
build-docker-cpu:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: am2.linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral"
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
@ -207,7 +221,8 @@ jobs:
.ci/docker/manywheel/build.sh manylinux-builder:cpu
build-docker-cpu-manylinux_2_28:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
env:
GPU_ARCH_TYPE: cpu-manylinux_2_28
steps:
@ -238,7 +253,8 @@ jobs:
.ci/docker/manywheel/build.sh manylinux2_28-builder:cpu
build-docker-cpu-aarch64:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.arm64.2xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral"
env:
GPU_ARCH_TYPE: cpu-aarch64
steps:
@ -269,7 +285,8 @@ jobs:
.ci/docker/manywheel/build.sh manylinuxaarch64-builder:cpu-aarch64
build-docker-cpu-aarch64-2_28:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.arm64.2xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral"
env:
GPU_ARCH_TYPE: cpu-aarch64-2_28
steps:
@ -303,7 +320,8 @@ jobs:
.ci/docker/manywheel/build.sh manylinux2_28_aarch64-builder:cpu-aarch64
build-docker-cpu-cxx11-abi:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
env:
GPU_ARCH_TYPE: cpu-cxx11-abi
steps:
@ -334,7 +352,8 @@ jobs:
.ci/docker/manywheel/build.sh manylinuxcxx11-abi-builder:cpu-cxx11-abi
build-docker-xpu:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
runs-on: linux.9xlarge.ephemeral
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
env:
GPU_ARCH_TYPE: xpu
steps:

View File

@ -13,7 +13,6 @@ on:
- .github/scripts/build_triton_wheel.py
- .github/ci_commit_pins/triton.txt
- .ci/docker/ci_commit_pins/triton.txt
- .ci/docker/ci_commit_pins/triton-rocm.txt
- .ci/docker/ci_commit_pins/triton-xpu.txt
pull_request:
paths:
@ -21,7 +20,6 @@ on:
- .github/scripts/build_triton_wheel.py
- .github/ci_commit_pins/triton.txt
- .ci/docker/ci_commit_pins/triton.txt
- .ci/docker/ci_commit_pins/triton-rocm.txt
- .ci/docker/ci_commit_pins/triton-xpu.txt
concurrency:
@ -29,9 +27,19 @@ concurrency:
cancel-in-progress: true
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
build-wheel:
name: "Build Triton Wheel"
runs-on: [self-hosted, linux.2xlarge]
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
strategy:
fail-fast: false
matrix:
@ -120,7 +128,7 @@ jobs:
fi
docker exec -t "${container_name}" chown -R 1000.1000 /artifacts
- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4.4.0
with:
name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }}
if-no-files-found: error
@ -201,7 +209,8 @@ jobs:
build-conda:
name: "Build Triton Conda"
runs-on: [self-hosted, linux.2xlarge]
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
strategy:
fail-fast: false
matrix:
@ -253,7 +262,7 @@ jobs:
docker exec -t "${container_name}" python /pytorch/.github/scripts/build_triton_wheel.py --build-conda --py-version="${PY_VERS}" $RELEASE
docker exec -t "${container_name}" chown -R 1000.1000 /artifacts
- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4.4.0
with:
name: pytorch-triton-conda-${{ matrix.py_vers }}
if-no-files-found: error

View File

@ -16,6 +16,15 @@ on:
paths: [.github/workflows/create_release.yml]
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
release:
if: ${{ github.repository == 'pytorch/pytorch' }}
name: Create Release
@ -63,7 +72,7 @@ jobs:
files: ${{env.PT_RELEASE_FILE}}
- name: Upload source distribution to GHA artifacts for release tags
if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4.4.0
with:
name: ${{ env.PT_RELEASE_FILE }}
path: ${{ env.PT_RELEASE_FILE }}
@ -73,12 +82,14 @@ jobs:
upload_source_code_to_s3:
if: ${{ github.repository == 'pytorch/pytorch' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
runs-on: linux.2xlarge
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
environment: sourcecode-upload
name: Upload source code to S3 for release tags
permissions:
id-token: write
needs: release
needs:
- get-label-type
- release
steps:
- uses: actions/download-artifact@v4.1.7
with:

View File

@ -30,8 +30,18 @@ env:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
docker-build:
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
needs: get-label-type
timeout-minutes: 240
strategy:
fail-fast: false
@ -68,7 +78,7 @@ jobs:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600
runs-on: [self-hosted, "${{ matrix.runner }}"]
runs-on: "${{ needs.get-label-type.outputs.label-type }}${{ matrix.runner }}"
env:
DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/${{ matrix.docker-image-name }}
steps:

View File

@ -34,9 +34,19 @@ env:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
generate-matrix:
if: github.repository_owner == 'pytorch'
runs-on: [self-hosted, linux.large]
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.large"
outputs:
matrix: ${{ steps.generate-matrix.outputs.matrix }}
steps:
@ -54,10 +64,12 @@ jobs:
build:
if: ${{ github.repository == 'pytorch/pytorch' }}
runs-on: [self-hosted, linux.2xlarge]
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
environment: ${{ (github.ref == 'refs/heads/nightly' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }}
timeout-minutes: 240
needs: generate-matrix
needs:
- generate-matrix
- get-label-type
strategy:
matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }}
fail-fast: false

View File

@ -58,6 +58,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.9"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
ALPINE_IMAGE: "arm64v8/alpine"
@ -81,6 +82,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -103,6 +105,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cpu-aarch64
secrets:
@ -125,6 +128,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main
DESIRED_DEVTOOLSET: cxx11-abi
use_split_build: False
DESIRED_PYTHON: "3.9"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
ALPINE_IMAGE: "arm64v8/alpine"
@ -149,6 +153,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main
DESIRED_DEVTOOLSET: cxx11-abi
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda-aarch64
secrets:
@ -170,6 +175,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.10"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
ALPINE_IMAGE: "arm64v8/alpine"
@ -193,6 +199,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -215,6 +222,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cpu-aarch64
secrets:
@ -237,6 +245,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main
DESIRED_DEVTOOLSET: cxx11-abi
use_split_build: False
DESIRED_PYTHON: "3.10"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
ALPINE_IMAGE: "arm64v8/alpine"
@ -261,6 +270,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main
DESIRED_DEVTOOLSET: cxx11-abi
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cuda-aarch64
secrets:
@ -282,6 +292,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.11"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
ALPINE_IMAGE: "arm64v8/alpine"
@ -305,6 +316,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -327,6 +339,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cpu-aarch64
secrets:
@ -349,6 +362,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main
DESIRED_DEVTOOLSET: cxx11-abi
use_split_build: False
DESIRED_PYTHON: "3.11"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
ALPINE_IMAGE: "arm64v8/alpine"
@ -373,6 +387,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main
DESIRED_DEVTOOLSET: cxx11-abi
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cuda-aarch64
secrets:
@ -394,6 +409,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.12"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
ALPINE_IMAGE: "arm64v8/alpine"
@ -417,6 +433,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -439,6 +456,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cpu-aarch64
secrets:
@ -461,6 +479,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main
DESIRED_DEVTOOLSET: cxx11-abi
use_split_build: False
DESIRED_PYTHON: "3.12"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
ALPINE_IMAGE: "arm64v8/alpine"
@ -485,6 +504,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main
DESIRED_DEVTOOLSET: cxx11-abi
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cuda-aarch64
secrets:


@ -54,6 +54,7 @@ jobs:
GPU_ARCH_VERSION: 11.8
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda11_8
@ -77,6 +78,7 @@ jobs:
GPU_ARCH_VERSION: 11.8
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda11_8
build_environment: linux-binary-manywheel
@ -85,53 +87,6 @@ jobs:
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda11_8-split-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu118
GPU_ARCH_VERSION: 11.8
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main
use_split_build: True
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda11_8-split
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda11_8-split-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_9-cuda11_8-split-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu118
GPU_ARCH_VERSION: 11.8
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main
use_split_build: True
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda11_8-split
build_environment: linux-binary-manywheel
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_1-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
@ -146,6 +101,7 @@ jobs:
GPU_ARCH_VERSION: 12.1
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda12_1
@ -169,6 +125,7 @@ jobs:
GPU_ARCH_VERSION: 12.1
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_1
build_environment: linux-binary-manywheel
@ -177,53 +134,6 @@ jobs:
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_1-split-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu121
GPU_ARCH_VERSION: 12.1
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
use_split_build: True
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda12_1-split
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_1-split-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_9-cuda12_1-split-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu121
GPU_ARCH_VERSION: 12.1
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
use_split_build: True
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_1-split
build_environment: linux-binary-manywheel
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
@ -238,6 +148,7 @@ jobs:
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda12_4
@ -261,6 +172,7 @@ jobs:
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_4
build_environment: linux-binary-manywheel
@ -268,50 +180,3 @@ jobs:
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_4-split-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
use_split_build: True
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda12_4-split
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_4-split-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_9-cuda12_4-split-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
use_split_build: True
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_4-split
build_environment: linux-binary-manywheel
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

File diff suppressed because it is too large


@ -0,0 +1,182 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-manywheel-split
on:
push:
branches:
- main
tags:
- 'ciflow/periodic/*'
workflow_dispatch:
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
ANACONDA_USER: pytorch
AWS_DEFAULT_REGION: us-east-1
BINARY_ENV_FILE: /tmp/env
BUILD_ENVIRONMENT: linux-binary-manywheel-split
BUILDER_ROOT: /builder
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
concurrency:
group: linux-binary-manywheel-split-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
manywheel-py3_9-cuda11_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu118
GPU_ARCH_VERSION: 11.8
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main
use_split_build: True
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda11_8
build_environment: linux-binary-manywheel-split
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda11_8-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_9-cuda11_8-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu118
GPU_ARCH_VERSION: 11.8
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main
use_split_build: True
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda11_8
build_environment: linux-binary-manywheel-split
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_1-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu121
GPU_ARCH_VERSION: 12.1
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
use_split_build: True
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda12_1
build_environment: linux-binary-manywheel-split
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_1-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_9-cuda12_1-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu121
GPU_ARCH_VERSION: 12.1
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
use_split_build: True
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_1
build_environment: linux-binary-manywheel-split
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
use_split_build: True
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda12_4
build_environment: linux-binary-manywheel-split
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_4-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_9-cuda12_4-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
use_split_build: True
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_4
build_environment: linux-binary-manywheel-split
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

File diff suppressed because it is too large


@ -58,6 +58,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.9"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -81,6 +82,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -103,6 +105,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cpu-s390x
secrets:
@ -124,6 +127,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.10"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -147,6 +151,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -169,6 +174,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cpu-s390x
secrets:
@ -190,6 +196,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.11"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -213,6 +220,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -235,6 +243,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cpu-s390x
secrets:
@ -256,6 +265,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.12"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -279,6 +289,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -301,6 +312,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cpu-s390x
secrets:
@ -322,6 +334,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.13"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -345,6 +358,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.13"
build_name: manywheel-py3_13-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -367,6 +381,7 @@ jobs:
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main
use_split_build: False
DESIRED_PYTHON: "3.13"
build_name: manywheel-py3_13-cpu-s390x
secrets:


@ -49,7 +49,7 @@ jobs:
DESIRED_DEVTOOLSET: cxx11-abi
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the


@ -51,7 +51,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -169,7 +169,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash


@ -58,7 +58,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -176,7 +176,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -290,7 +290,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
build_name: libtorch-cpu-shared-with-deps-debug
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -316,7 +316,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -435,7 +435,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -550,7 +550,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
build_name: libtorch-cuda11_8-shared-with-deps-debug
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -576,7 +576,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -695,7 +695,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -810,7 +810,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
build_name: libtorch-cuda12_1-shared-with-deps-debug
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -836,7 +836,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -955,7 +955,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -1070,7 +1070,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
build_name: libtorch-cuda12_4-shared-with-deps-debug
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}


@ -51,7 +51,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -169,7 +169,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash


@ -58,7 +58,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -176,7 +176,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -290,7 +290,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
build_name: libtorch-cpu-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -316,7 +316,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -435,7 +435,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -550,7 +550,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
build_name: libtorch-cuda11_8-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -576,7 +576,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -695,7 +695,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -810,7 +810,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
build_name: libtorch-cuda12_1-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -836,7 +836,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -955,7 +955,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
steps:
- name: Display EC2 information
shell: bash
@ -1070,7 +1070,7 @@ jobs:
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
DESIRED_PYTHON: "3.9"
build_name: libtorch-cuda12_4-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}


@ -18,11 +18,22 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-cuda12_4-py3_10-gcc9-inductor-build:
# Should be synced with the one in inductor.yml, but this doesn't run inductor_timm
name: cuda12.4-py3.10-gcc9-sm86
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86
docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks


@ -0,0 +1,40 @@
name: inductor-micro-benchmark-x86
on:
schedule:
- cron: 0 7 * * *
push:
tags:
- ciflow/inductor-micro-benchmark-cpu-x86/*
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
linux-jammy-cpu-py3_9-gcc11-inductor-build:
name: linux-jammy-cpu-py3.9-gcc11-inductor
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-py3.9-gcc11
docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks
# Use metal host for benchmark jobs
test-matrix: |
{ include: [
{ config: "inductor-micro-benchmark-cpu-x86", shard: 1, num_shards: 1, runner: "linux.24xl.spr-metal" },
]}
linux-jammy-cpu-py3_9-gcc11-inductor-micro-benchmark-test:
name: linux-jammy-cpu-py3.9-gcc11-inductor
uses: ./.github/workflows/_linux-test.yml
needs: linux-jammy-cpu-py3_9-gcc11-inductor-build
with:
build-environment: linux-jammy-py3.9-gcc11
docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha
timeout-minutes: 720


@ -16,10 +16,21 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-cuda12_1-py3_10-gcc9-inductor-micro-benchmark-build:
name: cuda12.1-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'


@ -13,10 +13,21 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-cuda12_1-py3_10-gcc9-inductor-build:
name: cuda12.1-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'


@ -68,10 +68,21 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-cuda12_1-py3_10-gcc9-inductor-build:
name: cuda12.1-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'


@ -50,10 +50,21 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-aarch64-py3_10-inductor-build:
name: linux-jammy-aarch64-py3.10-inductor
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.arm64.m7g.4xlarge
build-environment: linux-jammy-aarch64-py3.10
docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks


@ -48,10 +48,21 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-cpu-py3_9-gcc11-inductor-build:
name: linux-jammy-cpu-py3.9-gcc11-inductor
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.9-gcc11-build
docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks
test-matrix: |


@ -66,10 +66,21 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-cuda12_1-py3_10-gcc9-inductor-build:
name: cuda12.1-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'


@ -18,10 +18,21 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-cuda12_1-py3_10-gcc9-periodic-dynamo-benchmarks-build:
name: cuda12.1-py3.10-gcc9-sm86-periodic-dynamo-benchmarks
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.6'
@ -60,7 +71,9 @@ jobs:
linux-focal-cuda12_1-py3_10-gcc9-inductor-build-gcp:
name: cuda12.1-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'


@ -22,10 +22,21 @@ concurrency:
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-rocm6_1-py3_8-inductor-build:
name: rocm6.1-py3.8-inductor
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.1-py3.8
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |


@ -223,7 +223,7 @@ jobs:
cache: pip
- name: Install dependencies
run: |
pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.1.* numpy==1.24.*
pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.2.* fbscribelogger==0.1.* numpy==1.24.*
pip install torch --pre --index-url https://download.pytorch.org/whl/nightly/cpu/
- name: Run run_test.py (nonretryable)
run: |


@ -57,8 +57,10 @@ jobs:
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9
test-matrix: |
{ include: [
{ config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
]}
linux-focal-cuda12_1-py3_10-gcc9-test:
@ -87,8 +89,10 @@ jobs:
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
]}
@ -333,8 +337,10 @@ jobs:
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9
test-matrix: |
{ include: [
{ config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
]}


@ -3,18 +3,12 @@ name: rocm
on:
push:
branches:
# - main
- main
- release/*
tags:
- ciflow/rocm/*
workflow_dispatch:
schedule:
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
# Also run less frequently on weekends.
- cron: 45 0,8,16 * * 1-5
- cron: 45 4 * * 0,6
- cron: 45 4,12,20 * * 1-5
- cron: 45 12 * * 0,6
- cron: 29 8 * * * # about 1:29am PDT
concurrency:


@ -56,12 +56,14 @@ jobs:
cuda-arch-list: 8.6
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 1, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 6, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 7, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 8, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
]}
linux-focal-cuda12_1-py3-gcc9-slow-gradcheck-test:
@ -87,8 +89,9 @@ jobs:
cuda-arch-list: 8.6
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
]}
linux-focal-cuda12_1-py3_10-gcc9-sm86-test:


@ -10,8 +10,18 @@ permissions:
contents: read
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
index:
runs-on: linux.g5.4xlarge.nvidia.gpu # 1 GPU A10G 24GB each
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" # 1 GPU A10G 24GB each
environment: target-determinator-env
steps:
- name: Clone PyTorch


@ -11,10 +11,21 @@ concurrency:
cancel-in-progress: true
jobs:
get-label-type:
name: get-label-type
uses: ./.github/workflows/_runner-determinator.yml
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-cuda12_1-py3_10-gcc9-torchbench-build-gcp:
name: cuda12.1-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'


@ -266,8 +266,10 @@ jobs:
docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9
test-matrix: |
{ include: [
{ config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },


@ -2,7 +2,7 @@ name: Upload test stats
on:
workflow_run:
workflows: [pull, trunk, periodic, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, inductor-micro-benchmark, inductor-cu124, inductor-rocm]
workflows: [pull, trunk, periodic, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, inductor-micro-benchmark, inductor-micro-benchmark-x86, inductor-cu124, inductor-rocm]
types:
- completed
@ -96,7 +96,7 @@ jobs:
python3 -m tools.stats.check_disabled_tests --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}"
- name: Upload gpt-fast benchmark results to Rockset
if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' && github.event.workflow_run.name == 'inductor-micro-benchmark'
if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' && contains('inductor-micro-benchmark', github.event.workflow_run.name)
env:
ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }}
WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }}


@ -138,7 +138,7 @@ init_command = [
'--dry-run={{DRYRUN}}',
'numpy==1.24.3 ; python_version == "3.8"',
'numpy==1.26.0 ; python_version >= "3.9"',
'expecttest==0.1.6',
'expecttest==0.2.1',
'mypy==1.10.0',
'sympy==1.12.1 ; python_version == "3.8"',
'sympy==1.13.0 ; python_version >= "3.9"',
@ -210,6 +210,8 @@ include_patterns = [
'aten/src/ATen/native/nested/*.h',
'c10/**/*.cpp',
'c10/**/*.h',
'caffe2/**/*.cc',
'caffe2/**/*.h',
'torch/*.h',
'torch/csrc/*.h',
'torch/csrc/*.cpp',


@ -65,7 +65,6 @@ cxx_library(
"caffe2/serialize/file_adapter.cc",
"caffe2/serialize/inline_container.cc",
"caffe2/serialize/istream_adapter.cc",
"caffe2/serialize/read_adapter_interface.cc",
],
visibility = ["PUBLIC"],
deps = [


@ -332,6 +332,7 @@ intern_build_aten_ops(
"@fbgemm",
"@mkl",
"@sleef",
"@mkl_dnn//:mkl-dnn",
],
)
@ -472,7 +473,6 @@ filegroup(
"caffe2/serialize/file_adapter.cc",
"caffe2/serialize/inline_container.cc",
"caffe2/serialize/istream_adapter.cc",
"caffe2/serialize/read_adapter_interface.cc",
],
)


@ -57,7 +57,6 @@ nn/qat/ @jerryzh168
# Docker
/.ci/docker/ @jeffdaily
/.ci/docker/ci_commit_pins/triton.txt @desertfire @Chillee @eellison @shunting314 @bertmaher @jeffdaily @jataylo @jithunnair-amd @pruthvistony
/.ci/docker/ci_commit_pins/triton-rocm.txt @jeffdaily @jataylo @jithunnair-amd @pruthvistony
/.ci/docker/ci_commit_pins/triton-xpu.txt @EikanWang @gujinghui
# Github Actions


@ -50,6 +50,7 @@ Following is the Release Compatibility Matrix for PyTorch releases:
| PyTorch version | Python | Stable CUDA | Experimental CUDA | Stable ROCm |
| --- | --- | --- | --- | --- |
| 2.5 | >=3.9, <=3.12, (3.13 experimental) | CUDA 11.8, CUDA 12.1, CUDA 12.4, CUDNN 9.1.0.70 | None | ROCm 6.2 |
| 2.4 | >=3.8, <=3.12 | CUDA 11.8, CUDA 12.1, CUDNN 9.1.0.70 | CUDA 12.4, CUDNN 9.1.0.70 | ROCm 6.1 |
| 2.3 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 6.0 |
| 2.2 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 5.7 |


@ -299,6 +299,15 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
@ -307,6 +316,26 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
#define AT_DISPATCH_FLOATING_TYPES_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
__VA_ARGS__))
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
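The new AND5 variants follow the same pattern as the existing AND2/AND3/AND4 macros: they extend the floating-point case list with five extra scalar types and bind scalar_t inside the lambda. A hedged usage sketch follows; the five extra dtypes, the kernel name, and the TensorIterator-based iter/cpu_kernel context are illustrative assumptions, not taken from this diff.

// Hypothetical kernel body; the extra dtypes are example choices only.
AT_DISPATCH_FLOATING_TYPES_AND5(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    at::ScalarType::Float8_e4m3fn,
    at::ScalarType::Float8_e5m2,
    at::ScalarType::Float8_e4m3fnuz,
    iter.dtype(),
    "my_unary_op_cpu",
    [&]() {
      // scalar_t is the concrete type matched by the dispatch switch
      cpu_kernel(iter, [](scalar_t a) -> scalar_t { return a; });
    });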

View File

@ -707,7 +707,12 @@ bool are_all_mutations_under_no_grad_or_inference_mode(const Tensor& functional_
}
bool isFunctionalTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}
bool isBaseTensor(const at::Tensor& tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor));
return unsafeGetFunctionalWrapper(tensor)->isBaseTensor();
}
bool isFunctionalTensor(const std::optional<Tensor>& t) {

View File

@ -165,6 +165,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
was_storage_changed_ = true;
}
// A FunctionalTensor is considered a base if it's not a view of another
// tensor.
bool isBaseTensor() const {
return view_metas_.empty();
}
c10::SymInt get_storage_size(bool before) {
return functional_storage_impl()->get_storage_size(before);
}
@ -290,6 +296,8 @@ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
return functional_impl;
}
TORCH_API bool isBaseTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(

View File

@ -69,7 +69,7 @@ thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
at::ScalarType::Undefined, // Vulkan
at::ScalarType::Undefined, // Metal
at::kHalf, // XPU
at::ScalarType::Undefined, // MPS
at::kHalf, // MPS
at::ScalarType::Undefined, // Meta (tensors with no data)
at::kBFloat16, // HPU / HABANA
at::ScalarType::Undefined, // SX-Aurora / NEC
@ -206,6 +206,118 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}
TORCH_LIBRARY_IMPL(_, AutocastMPS, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
// lower_precision_fp
KERNEL_MPS2(_convolution, deprecated, lower_precision_fp)
KERNEL_MPS(_convolution, lower_precision_fp)
KERNEL_MPS(conv1d, lower_precision_fp)
KERNEL_MPS(conv2d, lower_precision_fp)
KERNEL_MPS(conv_tbc, lower_precision_fp)
KERNEL_MPS(conv_transpose1d, lower_precision_fp)
KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp)
KERNEL_MPS(convolution, lower_precision_fp)
KERNEL_MPS(_mps_convolution, lower_precision_fp)
KERNEL_MPS(prelu, lower_precision_fp)
KERNEL_MPS(addmm, lower_precision_fp)
KERNEL_MPS(addmv, lower_precision_fp)
KERNEL_MPS(addr, lower_precision_fp)
KERNEL_MPS(matmul, lower_precision_fp)
KERNEL_MPS(einsum, lower_precision_fp)
KERNEL_MPS(mm, lower_precision_fp)
KERNEL_MPS(mv, lower_precision_fp)
KERNEL_MPS(linear, lower_precision_fp)
KERNEL_MPS(addbmm, lower_precision_fp)
KERNEL_MPS(baddbmm, lower_precision_fp)
KERNEL_MPS(bmm, lower_precision_fp)
KERNEL_MPS(chain_matmul, lower_precision_fp)
KERNEL_MPS(linalg_multi_dot, lower_precision_fp)
KERNEL_MPS(lstm_cell, lower_precision_fp)
// fp32
KERNEL_MPS(acos, fp32)
KERNEL_MPS(asin, fp32)
KERNEL_MPS(cosh, fp32)
KERNEL_MPS(erfinv, fp32)
KERNEL_MPS(exp, fp32)
KERNEL_MPS(expm1, fp32)
KERNEL_MPS(log, fp32)
KERNEL_MPS(log10, fp32)
KERNEL_MPS(log2, fp32)
KERNEL_MPS(log1p, fp32)
KERNEL_MPS(reciprocal, fp32)
KERNEL_MPS(rsqrt, fp32)
KERNEL_MPS(sinh, fp32)
KERNEL_MPS(tan, fp32)
KERNEL_MPS2(pow, Tensor_Scalar, fp32)
KERNEL_MPS2(pow, Tensor_Tensor, fp32)
KERNEL_MPS2(pow, Scalar, fp32)
KERNEL_MPS(softplus, fp32)
KERNEL_MPS(layer_norm, fp32)
KERNEL_MPS(native_layer_norm, fp32)
KERNEL_MPS(group_norm, fp32)
KERNEL_MPS2(frobenius_norm, dim, fp32)
KERNEL_MPS(nuclear_norm, fp32)
KERNEL_MPS2(nuclear_norm, dim, fp32)
KERNEL_MPS(batch_norm, fp32)
KERNEL_MPS(cosine_similarity, fp32)
KERNEL_MPS(poisson_nll_loss, fp32)
KERNEL_MPS(cosine_embedding_loss, fp32)
KERNEL_MPS(nll_loss, fp32)
KERNEL_MPS(nll_loss2d, fp32)
KERNEL_MPS(hinge_embedding_loss, fp32)
KERNEL_MPS(kl_div, fp32)
KERNEL_MPS(l1_loss, fp32)
KERNEL_MPS(smooth_l1_loss, fp32)
KERNEL_MPS(huber_loss, fp32)
KERNEL_MPS(mse_loss, fp32)
KERNEL_MPS(margin_ranking_loss, fp32)
KERNEL_MPS(multilabel_margin_loss, fp32)
KERNEL_MPS(soft_margin_loss, fp32)
KERNEL_MPS(triplet_margin_loss, fp32)
KERNEL_MPS(multi_margin_loss, fp32)
KERNEL_MPS(binary_cross_entropy_with_logits, fp32)
KERNEL_MPS(dist, fp32)
KERNEL_MPS(pdist, fp32)
KERNEL_MPS(cdist, fp32)
KERNEL_MPS(renorm, fp32)
KERNEL_MPS(logsumexp, fp32)
// fp32_set_opt_dtype
KERNEL_MPS(prod, fp32)
KERNEL_MPS2(prod, dim_int, fp32)
KERNEL_MPS2(prod, dim_Dimname, fp32)
KERNEL_MPS2(softmax, int, fp32)
KERNEL_MPS2(softmax, Dimname, fp32)
KERNEL_MPS2(log_softmax, int, fp32)
KERNEL_MPS2(log_softmax, Dimname, fp32)
KERNEL_MPS(cumprod, fp32)
KERNEL_MPS2(cumprod, dimname, fp32)
KERNEL_MPS(cumsum, fp32)
KERNEL_MPS2(cumsum, dimname, fp32)
KERNEL_MPS(linalg_vector_norm, fp32)
KERNEL_MPS(linalg_matrix_norm, fp32)
KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32)
KERNEL_MPS(sum, fp32)
KERNEL_MPS2(sum, dim_IntList, fp32)
KERNEL_MPS2(sum, dim_DimnameList, fp32)
//
// promote
KERNEL_MPS(addcdiv, promote)
KERNEL_MPS(addcmul, promote)
KERNEL_MPS(atan2, promote)
KERNEL_MPS(bilinear, promote)
KERNEL_MPS(cross, promote)
KERNEL_MPS(dot, promote)
KERNEL_MPS(grid_sampler, promote)
KERNEL_MPS(index_put, promote)
KERNEL_MPS(tensordot, promote)
KERNEL_MPS(scatter_add, promote)
}
TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
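Taken together, the registrations above give MPS the same three-tier policy that CUDA and CPU autocast use: matmul-like ops drop to the autocast dtype (kHalf per the thread-local default change above), numerically sensitive ops are pinned to fp32, and mixed-dtype ops promote. A minimal sketch of exercising it, assuming the device-generic autocast helpers (set_autocast_enabled, clear_cache) are available as in recent ATen; this is not part of the diff.

#include <ATen/ATen.h>
#include <ATen/autocast_mode.h>

// Hedged sketch: run two MPS ops under AutocastMPS.
void mps_autocast_demo(const at::Tensor& a, const at::Tensor& b) {
  at::autocast::set_autocast_enabled(at::kMPS, true);  // assumed device-generic API
  auto fast = at::matmul(a, b);          // lower_precision_fp policy -> runs in kHalf
  auto stable = at::logsumexp(a, {-1});  // fp32 policy -> stays in float
  at::autocast::clear_cache();
  at::autocast::set_autocast_enabled(at::kMPS, false);
}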

View File

@ -145,6 +145,8 @@ inline bool is_autocast_eligible(
return tensor.is_xla() && tensor.is_floating_point();
case c10::DeviceType::PrivateUse1:
return tensor.is_privateuseone() && tensor.is_floating_point();
case c10::DeviceType::MPS:
return tensor.is_mps() && tensor.is_floating_point();
default:
return false;
}
@ -168,6 +170,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
return DispatchKey::AutocastXLA;
case c10::DeviceType::PrivateUse1:
return DispatchKey::AutocastPrivateUse1;
case c10::DeviceType::MPS:
return DispatchKey::AutocastMPS;
default:
throw std::runtime_error(
"unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
@ -178,7 +182,7 @@ inline bool is_autocast_available(c10::DeviceType device_type) {
if (device_type == at::kCPU || device_type == at::kCUDA ||
device_type == at::kXPU || device_type == at::kIPU ||
device_type == at::kHPU || device_type == at::kXLA ||
device_type == at::kPrivateUse1) {
device_type == at::kPrivateUse1 || device_type == at::kMPS) {
return true;
} else {
return false;
@ -745,6 +749,27 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_MPS registration for AutocastMPS
#define KERNEL_MPS(OP, POLICY) \
m.impl( \
TORCH_SELECTIVE_NAME("aten::" #OP), \
&WrapFunction< \
CastPolicy::POLICY, \
DeviceType::MPS, \
decltype(ATEN_FN(OP)), \
decltype(ATEN_FN(OP)), \
&ATEN_FN(OP)>::type::call);
#define KERNEL_MPS2(OP, OVERLOAD, POLICY) \
m.impl( \
TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
&WrapFunction< \
CastPolicy::POLICY, \
DeviceType::MPS, \
decltype(ATEN_FN2(OP, OVERLOAD)), \
decltype(ATEN_FN2(OP, OVERLOAD)), \
&ATEN_FN2(OP, OVERLOAD)>::type::call);
// Op lists for different policies.
// To make sure other backends can reuse the policy op list.
#define AT_FORALL_LOWER_PRECISION_FP(_) \

View File

@ -228,6 +228,7 @@ namespace c10 {
_(aten, is_autocast_cpu_enabled) \
_(aten, is_autocast_xla_enabled) \
_(aten, get_autocast_dtype) \
_(aten, is_autocast_mps_enabled) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \

View File

@ -9,7 +9,7 @@
#endif
namespace at::cpu {
bool is_cpu_support_avx2() {
bool is_avx2_supported() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_avx2();
#else
@ -17,7 +17,7 @@ bool is_cpu_support_avx2() {
#endif
}
bool is_cpu_support_avx512() {
bool is_avx512_supported() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq();
#else
@ -25,7 +25,7 @@ bool is_cpu_support_avx512() {
#endif
}
bool is_cpu_support_avx512_vnni() {
bool is_avx512_vnni_supported() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni();
#else
@ -33,7 +33,15 @@ bool is_cpu_support_avx512_vnni() {
#endif
}
bool is_cpu_support_amx_tile() {
bool is_avx512_bf16_supported() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_avx512bf16();
#else
return false;
#endif
}
bool is_amx_tile_supported() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_amx_tile();
#else
@ -42,7 +50,7 @@ bool is_cpu_support_amx_tile() {
}
bool init_amx() {
if (!is_cpu_support_amx_tile()) {
if (!is_amx_tile_supported()) {
return false;
}

View File

@ -6,14 +6,17 @@
namespace at::cpu {
TORCH_API bool is_cpu_support_avx2();
TORCH_API bool is_cpu_support_avx512();
TORCH_API bool is_avx2_supported();
TORCH_API bool is_avx512_supported();
// Detect if CPU support Vector Neural Network Instruction.
TORCH_API bool is_cpu_support_avx512_vnni();
TORCH_API bool is_avx512_vnni_supported();
// Detect if CPU supports AVX512_BF16 ISA
TORCH_API bool is_avx512_bf16_supported();
// Detect if CPU support Advanced Matrix Extension.
TORCH_API bool is_cpu_support_amx_tile();
TORCH_API bool is_amx_tile_supported();
// Enable the system to use AMX instructions.
TORCH_API bool init_amx();
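The renames turn these helpers into plain capability predicates. A small hedged sketch of how a caller might branch on them; the include path and the caller itself are assumptions, not part of this diff.

#include <ATen/cpu/Utils.h>  // assumed header for at::cpu::is_*_supported()

// Pick the widest ISA path available; on non-x86 these all return false,
// so the scalar fallback (0) is selected.
int select_isa_level() {
  if (at::cpu::is_avx512_bf16_supported()) return 3;
  if (at::cpu::is_avx512_vnni_supported()) return 2;
  if (at::cpu::is_avx2_supported())        return 1;
  return 0;
}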

View File

@ -636,6 +636,21 @@ inline void transpose_mxn<float, 8, 8>(
_mm256_storeu_ps(&dst[7 * ld_dst], th);
}
template<>
inline void transpose_mxn<float, 16, 16>(
const float* src,
int64_t ld_src,
float* dst,
int64_t ld_dst) {
transpose_mxn<float, 8, 8>(
src, ld_src, dst, ld_dst);
transpose_mxn<float, 8, 8>(
src + 8, ld_src, dst + 8 * ld_dst, ld_dst);
transpose_mxn<float, 8, 8>(
src + 8 * ld_src, ld_src, dst + 8, ld_dst);
transpose_mxn<float, 8, 8>(
src + 8 * ld_src + 8, ld_src, dst + 8 * ld_dst + 8, ld_dst);
}
#endif
}} // namespace at::vec::CPU_CAPABILITY

View File

@ -582,8 +582,7 @@ Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<floa
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304
// kernel for transposing mxn where m, n <= 16
// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions
template <>
inline void transpose_mxn<float>(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) {
inline void transpose_mxn_16x16(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) {
TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn<float> expects M, N <= 16.");
// load from src to registers
__m512 input[16];
@ -667,8 +666,39 @@ inline void transpose_mxn<float>(const float* src, int64_t ld_src, float* dst, i
}
}
template<>
inline void transpose_mxn<float>(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) {
int64_t i = 0;
for (; i < M / 16 * 16; i += 16) {
int64_t j = 0;
for (; j < N / 16 * 16; j += 16) {
transpose_mxn_16x16(
src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, 16);
}
// handle remainder j
int nrem = N - j;
if (nrem > 0) {
transpose_mxn_16x16(
src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem);
}
}
// handle remainder i
int mrem = M - i;
if (mrem > 0) {
int j = 0;
for (; j < N / 16 * 16; j += 16) {
transpose_mxn_16x16(
src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16);
}
// handle remainder j
int nrem = N - j;
transpose_mxn_16x16(
src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem);
}
}
template <typename T, int M, int N,
typename std::enable_if_t<std::is_same<T, float>::value && M <= 16 && N <= 16, int> = 0>
typename std::enable_if_t<std::is_same<T, float>::value, int> = 0>
inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
transpose_mxn<float>(src, ld_src, dst, ld_dst, M, N);
}
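The new float specialization handles arbitrary M and N by walking full 16x16 tiles and then sweeping the right edge (nrem) and bottom edge (mrem) with partial tiles. A scalar reference of the same traversal, for illustration only; it is not the vectorized kernel above.

#include <algorithm>
#include <cstdint>

// Scalar illustration of the tiling: full 16x16 tiles first, then the
// nrem-wide right edge and mrem-tall bottom edge, mirroring the loop
// structure of the AVX-512 version above.
inline void transpose_blocked_ref(const float* src, int64_t ld_src,
                                  float* dst, int64_t ld_dst, int M, int N) {
  constexpr int B = 16;
  for (int i = 0; i < M; i += B) {
    const int mrem = std::min(B, M - i);
    for (int j = 0; j < N; j += B) {
      const int nrem = std::min(B, N - j);
      for (int ii = 0; ii < mrem; ++ii) {
        for (int jj = 0; jj < nrem; ++jj) {
          dst[(j + jj) * ld_dst + (i + ii)] = src[(i + ii) * ld_src + (j + jj)];
        }
      }
    }
  }
}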

View File

@ -1408,7 +1408,6 @@ void scaled_gemm(
const void *result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
void* amax_ptr,
bool use_fast_accum) {
#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
const auto computeType = CUBLAS_COMPUTE_32F;
@ -1421,13 +1420,9 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200)
// Amax support in ROCm as of 6.2
if (isFloat8Type(result_dtype)) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
if (result_scale_ptr != nullptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
}
#endif
#ifndef USE_ROCM
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
#endif

View File

@ -140,7 +140,6 @@ void scaled_gemm(
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
void* amax_ptr,
bool use_fast_accum);
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \

View File

@ -188,7 +188,10 @@ TuningResultsValidator::TuningResultsValidator() {
RegisterValidator(
"ROCM_VERSION",
[rocm_version]() { return rocm_version; },
[rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
[rocm_version](auto&& k) {
TUNABLE_LOG1("ROCM_VERSION validation: expect ", k, " to match ", rocm_version);
return rocm_version == k ? OK : FAIL;
});
}
// gfx arch
{
@ -196,7 +199,10 @@ TuningResultsValidator::TuningResultsValidator() {
RegisterValidator(
"GCN_ARCH_NAME",
[gcn_arch_name]() { return gcn_arch_name; },
[gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
[gcn_arch_name](auto&& k) {
TUNABLE_LOG1("GCN_ARCH_NAME validation: expect ", k, " to match ", gcn_arch_name);
return gcn_arch_name == k ? OK : FAIL;
});
}
// rocblas
{
@ -212,7 +218,10 @@ TuningResultsValidator::TuningResultsValidator() {
RegisterValidator(
"ROCBLAS_VERSION",
[rocblas_version]() { return rocblas_version; },
[rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
[rocblas_version](auto&& k) {
TUNABLE_LOG1("ROCBLAS_VERSION validation: expect ", k, " to match ", rocblas_version);
return rocblas_version == k ? OK : FAIL;
});
}
// hipblaslt
{
@ -226,7 +235,10 @@ TuningResultsValidator::TuningResultsValidator() {
RegisterValidator(
"HIPBLASLT_VERSION",
[hipblaslt_version]() { return hipblaslt_version; },
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
[hipblaslt_version](auto&& k) {
TUNABLE_LOG1("HIPBLASLT_VERSION validation: expect ", k, " to match ", hipblaslt_version);
return hipblaslt_version == k ? OK : FAIL;
});
}
#endif
}

View File

@ -104,7 +104,6 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->amax_ptr,
params->use_fast_accum);
return OK;
}

View File

@ -23,6 +23,9 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
OP_DECOMPOSE(dropout_);
OP_DECOMPOSE(feature_alpha_dropout_);
OP_DECOMPOSE(feature_dropout_);
OP_DECOMPOSE(dropout);
OP_DECOMPOSE(_scaled_dot_product_attention_math);
OP_DECOMPOSE(scaled_dot_product_attention);
}
static void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
@ -235,7 +238,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(relu6_);
OP_DECOMPOSE(prelu);
OP_DECOMPOSE2(softmax, int);
OP_DECOMPOSE(scaled_dot_product_attention);
OP_DECOMPOSE(special_gammainc);
OP_DECOMPOSE(special_gammaincc);
OP_DECOMPOSE(special_logit);
@ -261,7 +263,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(special_xlogy);
OP_DECOMPOSE2(special_xlogy, other_scalar);
OP_DECOMPOSE2(special_xlogy, self_scalar);
OP_DECOMPOSE(_scaled_dot_product_attention_math);
m.impl("split.sizes", native::split_symint);
@ -386,6 +387,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE2(to, dtype);
OP_DECOMPOSE2(to, dtype_layout);
OP_DECOMPOSE2(to, other);
// Random ops that are also registered here
OP_DECOMPOSE(dropout);
OP_DECOMPOSE(_scaled_dot_product_attention_math);
OP_DECOMPOSE(scaled_dot_product_attention);
}
} // namespace at::functorch

View File

@ -496,6 +496,11 @@ _scaled_dot_product_flash_attention_batch_rule(
bool return_debug_mask,
c10::optional<double> scale
) {
if (dropout_p > 0) {
auto maybe_layer = maybeCurrentDynamicLayer();
RandomnessType randomness = maybe_layer->randomness();
check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
}
auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
auto query_ = moveBatchDimToFront(query, query_bdim);
auto key_ = moveBatchDimToFront(key, key_bdim);
@ -540,6 +545,11 @@ fourOutputs _scaled_dot_product_efficient_attention_batch_rule(
bool is_causal,
c10::optional<double> scale
) {
if (dropout_p > 0) {
auto maybe_layer = maybeCurrentDynamicLayer();
RandomnessType randomness = maybe_layer->randomness();
check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
}
auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
auto query_ = moveBatchDimToFront(query, query_bdim);
auto key_ = moveBatchDimToFront(key, key_bdim);
@ -577,6 +587,11 @@ _scaled_dot_product_cudnn_attention_batch_rule(
bool return_debug_mask,
c10::optional<double> scale
) {
if (dropout_p > 0) {
auto maybe_layer = maybeCurrentDynamicLayer();
RandomnessType randomness = maybe_layer->randomness();
check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
}
auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
auto query_ = moveBatchDimToFront(query, query_bdim);
auto key_ = moveBatchDimToFront(key, key_bdim);

View File

@ -28,6 +28,7 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
_executionDescriptor.enableCommitAndContinue = _enableCommitAndContinue;
// Choose level which optimizes for GPU
[_compilationDescriptor disableTypeInference];
_compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
_executionDescriptor.compilationDescriptor = _compilationDescriptor;
}

View File

@ -41,6 +41,17 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *
#include <fbgemm/FbgemmI64.h>
#endif // USE_FBGEMM
#if AT_MKLDNN_ENABLED()
#include <oneapi/dnnl/dnnl_version.h>
#endif // oneDNN
#define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=5)
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
#include <oneapi/dnnl/dnnl_ukernel.hpp>
#include <oneapi/dnnl/dnnl.hpp>
#endif // oneDNN BRGEMM
namespace at::native::cpublas {
namespace internal {
@ -822,4 +833,350 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
n, x, incx, y, incy);
}
} // namespace at::native::cpublas
// oneDNN BRGEMM
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
struct BrgemmKey {
int64_t M;
int64_t N;
int64_t K;
int64_t batch_size;
int64_t lda;
int64_t ldb;
int64_t ldc;
ScalarType dt_a;
ScalarType dt_b;
ScalarType dt_c;
float alpha;
float beta;
BrgemmKey(
int64_t M,
int64_t N,
int64_t K,
int64_t batch_size,
int64_t lda,
int64_t ldb,
int64_t ldc,
ScalarType dt_a,
ScalarType dt_b,
ScalarType dt_c,
float alpha,
float beta)
: M(M),
N(N),
K(K),
batch_size(batch_size),
lda(lda),
ldb(ldb),
ldc(ldc),
dt_a(dt_a),
dt_b(dt_b),
dt_c(dt_c),
alpha(alpha),
beta(beta) {}
bool operator==(const BrgemmKey& other) const {
return M == other.M && N == other.N && K == other.K &&
batch_size == other.batch_size && lda == other.lda &&
ldb == other.ldb && ldc == other.ldc && dt_a == other.dt_a &&
dt_b == other.dt_b && dt_c == other.dt_c && alpha == other.alpha &&
beta == other.beta;
}
};
struct PackKey {
int64_t K;
int64_t N;
int64_t ld_in;
int64_t ld_out;
ScalarType dt_in;
ScalarType dt_out;
PackKey(
int64_t K,
int64_t N,
int64_t ld_in,
int64_t ld_out,
ScalarType dt_in,
ScalarType dt_out)
: K(K),
N(N),
ld_in(ld_in),
ld_out(ld_out),
dt_in(dt_in),
dt_out(dt_out) {}
bool operator==(const PackKey& other) const {
return N == other.N && K == other.K && ld_in == other.ld_in &&
ld_out == other.ld_out && dt_in == other.dt_in &&
dt_out == other.dt_out;
}
};
inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
if (dtype == ScalarType::Float) {
return dnnl::memory::data_type::f32;
} else if (dtype == ScalarType::BFloat16) {
return dnnl::memory::data_type::bf16;
} else if (dtype == ScalarType::Half) {
return dnnl::memory::data_type::f16;
} else if (dtype == ScalarType::Byte) {
return dnnl::memory::data_type::u8;
} else if (dtype == ScalarType::Char) {
return dnnl::memory::data_type::s8;
} else {
TORCH_CHECK(false, "get_dnnl_dtype expects float/bfloat16/half/int8 tensor input");
}
}
template<typename key_t>
struct UnsafeUkernelKeyHasher {
std::size_t operator()(const key_t& key) const;
};
template<>
std::size_t UnsafeUkernelKeyHasher<BrgemmKey>::operator()(const BrgemmKey& key) const {
// Use beta, M, N, and K to compute hash to reduce the overhead as
// batch size, alpha, and data types are unlikely to change within the same kernel and
// leading dimensions are likely to be related to M, K, N or use fixed values.
std::size_t h = std::hash<float>()(key.beta + 1);
h = std::hash<int64_t>()(key.M) ^ (h << 1);
h = std::hash<int64_t>()(key.N) ^ (h << 1);
h = std::hash<int64_t>()(key.K) ^ (h << 1);
h = std::hash<int64_t>()(key.ldc) ^ (h << 1);
return h;
}
template<>
std::size_t UnsafeUkernelKeyHasher<PackKey>::operator()(const PackKey& key) const {
// Use K and N to compute hash to reduce the overhead as
// data types are unlikely to change and
// ld_in/ld_out is likely to be related to K, N or use fixed values
std::size_t h = std::hash<int64_t>()(key.K);
h = std::hash<int64_t>()(key.N) ^ (h << 1);
return h;
}
template <typename key_t, typename value_t>
struct KernelCache {
using kstore_t = std::unordered_map<key_t, std::shared_ptr<value_t>, UnsafeUkernelKeyHasher<key_t>>;
static inline std::shared_ptr<value_t>&& fetch_or_create(
const key_t& key,
const std::function<std::shared_ptr<value_t>()>& callback) {
auto&& search = get_store().find(key);
if (search != get_store().end()) {
return std::move(search->second);
} else {
get_store().insert({key, callback()});
return std::move(get_store()[key]);
}
}
static inline kstore_t& get_store() {
static thread_local kstore_t cache_kernels;
return cache_kernels;
}
};
// Helper struct for convenient brgemm configuration
struct GemmHelper {
GemmHelper(
int64_t M,
int64_t N,
int64_t K,
int64_t bs,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
ScalarType dt_a,
ScalarType dt_b,
ScalarType dt_c,
const float alpha,
const float beta) {
// Create brgemm
brg = dnnl::ukernel::brgemm(
M,
N,
K,
bs,
ld_a,
ld_b,
ld_c,
get_dnnl_dtype(dt_a),
get_dnnl_dtype(dt_b),
get_dnnl_dtype(dt_c),
alpha,
beta);
// Create a scratchpad buffer for the brgemm execution
scratchpad = std::vector<uint8_t>(brg.get_scratchpad_size());
// Prepare default vector of pairs of tensors A and B offsets for each batch.
A_B_offsets.reserve(1);
A_B_offsets[0] = std::make_pair(0, 0);
}
dnnl::ukernel::brgemm brg;
std::vector<uint8_t> scratchpad;
std::vector<std::pair<int64_t, int64_t>> A_B_offsets;
};
struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
// Fetch/create GemmHelper object and execute brgemm with batch size = 1
template <typename scalar_t_a, typename scalar_t_b, typename scalar_t_c>
static inline void call(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const float alpha,
const float beta,
const scalar_t_a* A,
const scalar_t_b* B,
scalar_t_c* C) {
auto&& key = BrgemmKey(
M,
N,
K,
int64_t(1),
ld_a,
ld_b,
ld_c,
c10::CppTypeToScalarType<scalar_t_a>::value,
c10::CppTypeToScalarType<scalar_t_b>::value,
c10::CppTypeToScalarType<scalar_t_c>::value,
alpha,
beta);
// Fetch/create GemmHelper object
auto&& value = fetch_or_create(key, [&]() {
auto&& v = std::make_shared<GemmHelper>(
M,
N,
K,
1,
ld_a,
ld_b,
ld_c,
c10::CppTypeToScalarType<scalar_t_a>::value,
c10::CppTypeToScalarType<scalar_t_b>::value,
c10::CppTypeToScalarType<scalar_t_c>::value,
alpha,
beta);
(*v).brg.generate();
return std::move(v);
});
if (get_current() != value) {
dnnl::ukernel::brgemm::release_hw_context();
((*value).brg).set_hw_context();
get_current() = value;
}
((*value).brg)
.execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data());
}
static inline std::shared_ptr<GemmHelper>& get_current() {
static thread_local std::shared_ptr<GemmHelper> current;
return current;
}
static inline bool device_check(ScalarType dtype) {
if (!at::globalContext().userEnabledMkldnn()) {
return false;
}
if (dtype == ScalarType::Half) {
static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16;
return fp16_support;
}
return false;
}
};
using pack_t = dnnl::ukernel::brgemm_pack_B;
struct Pack : public KernelCache <PackKey, pack_t> {
static inline void call(
int64_t K,
int64_t N,
int64_t ld_in,
int64_t ld_out,
ScalarType dt_in,
ScalarType dt_out,
const void* in,
void* out) {
auto&& key = PackKey(K, N, ld_in, ld_out, dt_in, dt_out);
auto&& pack = fetch_or_create(key, [&]() {
auto&& p = std::make_shared<pack_t>(
K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out));
if (need_pack(dt_in)) {
(*p).generate();
}
return std::move(p);
});
if (need_pack(dt_in)) {
(*pack).execute(in, out);
} else {
TORCH_CHECK(false, "No need to pack");
}
}
static inline bool need_pack(ScalarType dtype) {
if (!at::globalContext().userEnabledMkldnn()) {
return false;
}
if (dtype == ScalarType::Half) {
static bool fp16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx_fp16;
return fp16_pack;
}
return false;
}
};
#endif
void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const float alpha,
const float beta,
const at::Half* A,
const at::Half* B,
float* C) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
if (Brgemm::device_check(ScalarType::Half)) {
Brgemm::call<at::Half, at::Half, float>(
M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C);
return;
}
#endif
TORCH_CHECK(false,
"Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported");
}
void brgemm_release() {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
dnnl::ukernel::brgemm::release_hw_context();
#endif
}
void pack(
int64_t K,
int64_t N,
int64_t ld_in,
int64_t ld_out,
ScalarType dt_in,
ScalarType dt_out,
const void* in,
void* out) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out);
#else
TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled");
#endif
}
bool need_pack(ScalarType dt_in) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
return Pack::need_pack(dt_in);
#else
return false;
#endif
}
} // namespace at::native::cpublas

View File

@ -7,6 +7,7 @@
#include <c10/core/ScalarType.h>
#include <c10/core/Scalar.h>
namespace at::native::cpublas {
namespace internal {
@ -186,4 +187,40 @@ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
} // namespace at::native::cpublas
// Batch-reduce GEMM
// Operates by the following formula:
// C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size
// A Base pointer to a tensor A.
// B Base pointer to a tensor B.
// C Pointer to a tensor C (accumulation buffer).
TORCH_API void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const float alpha,
const float beta,
const at::Half* A,
const at::Half* B,
float* C);
// Release brgemm hardware context
void brgemm_release();
// Pack B matrix to get better performance if needed
void pack(
int64_t K,
int64_t N,
int64_t ld_in,
int64_t ld_out,
ScalarType dt_in,
ScalarType dt_out,
const void* in,
void* out);
// Whether pack is needed on the platform.
bool need_pack(ScalarType dt_in);
} // namespace at::native::cpublas
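The comment above defines the batch-reduce GEMM contract: C = alpha * SUM_i(A[i] x B[i]) + beta * C. For clarity, a plain scalar reference of that formula; this is an illustration only, not the oneDNN-ukernel-backed implementation, and the exported brgemm above is its batch-size-1, Half-input specialization.

#include <cstdint>

// Scalar reference of the brgemm formula, generalized to a batch of
// A/B panels with row-major leading dimensions ld_a/ld_b/ld_c.
void brgemm_reference(int64_t M, int64_t N, int64_t K, int64_t batch,
                      int64_t ld_a, int64_t ld_b, int64_t ld_c,
                      float alpha, float beta,
                      const float* const* A, const float* const* B, float* C) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int64_t i = 0; i < batch; ++i) {
        for (int64_t k = 0; k < K; ++k) {
          acc += A[i][m * ld_a + k] * B[i][k * ld_b + n];
        }
      }
      C[m * ld_c + n] = alpha * acc + beta * C[m * ld_c + n];
    }
  }
}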

View File

@ -144,7 +144,7 @@ static void col2im_out_cpu_template(
output.resize_({batch_size, n_output_plane, output_height, output_width});
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf,
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
input.scalar_type(), "col2im_out_cpu", [&] {
Tensor input_n = Tensor();
Tensor output_n = Tensor();

View File

@ -421,12 +421,18 @@ struct ConvParams {
// cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest
// that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how)
#if !defined(C10_MOBILE)
if (needs_64bit_indexing_no_split(input, weight)) {
return false;
}
if (!detail::getCUDAHooks().compiledWithCuDNN()) {
return false;
}
if (needs_64bit_indexing_no_split(input, weight)) {
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
" if the V8 API is not enabled or before cuDNN version 9.3+."
" Consider upgrading cuDNN and/or enabling the V8 API for better efficiency.");
return false;
}
}
if (!input.is_cuda() || !cudnn_enabled) {
return false;
}

View File

@ -94,7 +94,7 @@ static void im2col_out_cpu_template(
output.resize_({batch_size, n_output_plane, output_length});
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf,
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
input.scalar_type(), "im2col_out_cpu", [&] {
Tensor input_n;
Tensor output_n;

View File

@ -19,6 +19,7 @@
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/native/mkldnn/Utils.h>
#include <c10/core/GradMode.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
@ -1358,13 +1359,8 @@ static inline int64_t get_mkldnn_matmul_min_dim() {
static auto value = [&] {
const int64_t default_min_dim = [&] {
// Minimum dimension requirement for MKLDNN; derived based on experiments.
// By default, it's only enabled on Neoverse V1.
#if !defined(__s390x__) && !defined(__powerpc__)
if (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 && cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1) {
return 8;
}
#endif
return 0;
// it's enabled on all Neoverse CPUs.
return is_arm_neoverse() ? 8 : 0;
}();
const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_DIM");
return ptr != nullptr ? std::atoi(ptr) : default_min_dim;
@ -1377,13 +1373,8 @@ static inline int64_t get_mkldnn_matmul_min_size() {
static auto value = [&] {
const int64_t default_min_size = [&] {
// Minimum size requirement for MKLDNN; derived based on experiments.
// By default, it's only enabled on Neoverse V1.
#if !defined(__s390x__) && !defined(__powerpc__)
if (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 && cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1) {
return 8 * 1024;
}
#endif
return 0;
// it's enabled on all Neoverse CPUs.
return is_arm_neoverse() ? 8 * 1024 : 0;
}();
const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_SIZE");
return ptr != nullptr ? std::atoi(ptr) : default_min_size;
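With this change the thresholds default to 8 (dims) and 8 * 1024 (elements) on every Neoverse core rather than only Neoverse V1, and the TORCH_MKLDNN_MATMUL_MIN_DIM / TORCH_MKLDNN_MATMUL_MIN_SIZE environment variables still override them. A hedged sketch of the kind of gate the surrounding matmul code applies; the helper name and exact comparisons below are illustrative.

// Illustration only: skip the MKLDNN path for tiny matmuls.
bool should_try_mkldnn_matmul(int64_t m, int64_t n, int64_t k) {
  const int64_t min_dim = get_mkldnn_matmul_min_dim();    // 8 on Neoverse, 0 otherwise
  const int64_t min_size = get_mkldnn_matmul_min_size();  // 8 * 1024 on Neoverse, 0 otherwise
  return m > min_dim && n > min_dim && k > min_dim && m * n * k > min_size;
}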

View File

@ -209,7 +209,13 @@ std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
bool all_contiguous = is_contiguous(input);
constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
const auto dtype = mixed_type ? kFloat : input.scalar_type();
// Using float data type for Half _var_sum in batchnorm stats updating on CPU
// to avoid _var_sum overflow since the representation range of Half is small.
using opmath_t = std::conditional_t<std::is_same_v<param_t, at::Half>, at::opmath_type<param_t>, param_t>;
auto dtype = mixed_type ? kFloat : input.scalar_type();
if (dtype == kHalf) {
dtype = kFloat;
}
auto save_mean_a = save_mean.accessor<param_t, 1>();
auto save_var_transform_a = save_var_transform.accessor<param_t, 1>();
@ -220,9 +226,9 @@ std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
if (all_contiguous) {
auto _mean = at::empty({n_input}, input.options().dtype(dtype));
auto _var_sum = at::empty({n_input}, input.options().dtype(dtype));
auto _mean_a = _mean.accessor<param_t, 1>();
auto _var_sum_a = _var_sum.accessor<param_t, 1>();
auto momentum_ = static_cast<param_t>(momentum);
auto _mean_a = _mean.accessor<opmath_t, 1>();
auto _var_sum_a = _var_sum.accessor<opmath_t, 1>();
auto momentum_ = static_cast<opmath_t>(momentum);
batch_norm_cpu_collect_stats_stub(kCPU, _mean, _var_sum, input);
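The fix accumulates the Half variance sum in float: at::opmath_type<at::Half> is float, so the running sum can exceed Half's ~65504 maximum without saturating. A hedged standalone sketch of the failure mode; the function below is illustrative and not part of this diff.

#include <ATen/OpMathType.h>
#include <c10/util/Half.h>
#include <cstdint>

// Illustration: summing squared deviations of many Half values.
// Accumulating in Half saturates at 65504; accumulating in
// at::opmath_type<at::Half> (== float) does not.
float half_var_sum(const at::Half* x, int64_t n, float mean) {
  using acc_t = at::opmath_type<at::Half>;  // float
  acc_t var_sum = 0;
  for (int64_t i = 0; i < n; ++i) {
    acc_t d = static_cast<acc_t>(x[i]) - mean;
    var_sum += d * d;
  }
  return var_sum;
}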

View File

@ -284,7 +284,7 @@ void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& newsize) {
} else if (device_type == at::kPrivateUse1) {
at::detail::getPrivateUse1Hooks().resizePrivateUse1Bytes(
storage, newsize.expect_int());
} else if (device_type == at::kXPU || device_type == at::kHPU) {
} else if (device_type == at::kXPU || device_type == at::kHPU || device_type == at::kMTIA) {
ptrdiff_t size_bytes_i = newsize.expect_int();
TORCH_CHECK(
!c10::overflows<int64_t>(size_bytes_i),

View File

@ -11,18 +11,18 @@ namespace ao {
namespace sparse {
namespace {
constexpr int64_t serialization_version_index = 0;
constexpr int64_t bias_index = 1;
constexpr int64_t out_features_block_size_index = 2;
constexpr int64_t in_features_block_size_index = 3;
constexpr int64_t weight_scales_index = 4;
constexpr int64_t weight_zero_point_index = 5;
constexpr int64_t quantization_scheme_index = 6;
constexpr int64_t row_block_indices_index = 7;
constexpr int64_t col_block_indices_index = 8;
constexpr int64_t weight_values_index = 9;
constexpr int64_t num_output_channels_index = 10;
constexpr int64_t num_input_channels_index = 11;
constexpr int64_t serialization_version_index [[maybe_unused]] = 0;
constexpr int64_t bias_index [[maybe_unused]] = 1;
constexpr int64_t out_features_block_size_index [[maybe_unused]] = 2;
constexpr int64_t in_features_block_size_index [[maybe_unused]] = 3;
constexpr int64_t weight_scales_index [[maybe_unused]] = 4;
constexpr int64_t weight_zero_point_index [[maybe_unused]] = 5;
constexpr int64_t quantization_scheme_index [[maybe_unused]] = 6;
constexpr int64_t row_block_indices_index [[maybe_unused]] = 7;
constexpr int64_t col_block_indices_index [[maybe_unused]] = 8;
constexpr int64_t weight_values_index [[maybe_unused]] = 9;
constexpr int64_t num_output_channels_index [[maybe_unused]] = 10;
constexpr int64_t num_input_channels_index [[maybe_unused]] = 11;
template <typename TENSOR_DTYPE, typename VEC_DTYPE>
std::vector<VEC_DTYPE> unwrap_vector(at::Tensor tensor) {

View File

@ -81,6 +81,12 @@ void atan2_kernel(TensorIteratorBase& iter) {
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_INTEGRAL_TYPES_V2(TYPE, NAME, ...) \
AT_DISPATCH_V2( \
TYPE, \
NAME, \
AT_WRAP(__VA_ARGS__), \
AT_EXPAND(AT_INTEGRAL_TYPES_V2))
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_V2( \
TYPE, \
@ -104,6 +110,8 @@ void atan2_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_INTEGRAL_TYPES_V2(TYPE, NAME, ...) \
AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kComplexHalf, kHalf, kBool, kBFloat16, TYPE, NAME, __VA_ARGS__)
@ -382,7 +390,7 @@ void bitwise_and_kernel(TensorIteratorBase& iter) {
if (iter.dtype() == ScalarType::Bool) {
cpu_kernel(iter, [](bool a, bool b) { return a && b; });
} else {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
_AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_and_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return a & b; },
@ -395,7 +403,7 @@ void bitwise_or_kernel(TensorIteratorBase& iter) {
if (iter.dtype() == ScalarType::Bool) {
cpu_kernel(iter, [](bool a, bool b) { return a || b; });
} else {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_or_cpu", [&]() {
_AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_or_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return a | b; },
@ -410,7 +418,7 @@ void bitwise_xor_kernel(TensorIteratorBase& iter) {
// this operation for both Boolean and integral types.
cpu_kernel(iter, [](bool a, bool b) { return a != b; });
} else {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_xor_cpu", [&]() {
_AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_xor_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return a ^ b; },

View File

@ -16,7 +16,6 @@
#else
#include <ATen/ops/empty.h>
#endif
namespace at::native {
namespace {
@ -202,7 +201,97 @@ void reshape_attn_mask_to_4d(
.expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
}
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
template <typename scalar_t>
inline void copy_value_with_pad(
const scalar_t* value_ptr,
scalar_t* dst_ptr,
int64_t rows,
int64_t cols,
int64_t prows,
int64_t pcols,
int64_t ldi) {
auto vec_size = at::vec::Vectorized<scalar_t>::size();
int64_t i = 0;
for (; i < rows; i++) {
int64_t j = 0;
for (; j < cols - (cols % vec_size); j += vec_size) {
auto vec_v =
at::vec::Vectorized<scalar_t>::loadu(value_ptr + i * ldi + j);
vec_v.store(dst_ptr + i * pcols + j);
}
if (j < cols) {
auto vec_v = at::vec::Vectorized<scalar_t>::loadu(
value_ptr + i * ldi + j, cols - j);
vec_v.store(dst_ptr + i * pcols + j, cols - j);
}
// col padding
auto psize = pcols - cols;
if (psize > 0) {
auto zero_vec = at::vec::Vectorized<scalar_t>(0);
int64_t pj = 0;
for (; pj < psize - (psize % vec_size); pj += vec_size) {
zero_vec.store(dst_ptr + i * pcols + cols + pj);
}
if (pj < psize) {
zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj);
}
}
}
// row padding
for (; i < prows; i++) {
auto zero_vec = at::vec::Vectorized<scalar_t>(0);
int64_t j = 0;
for (; j < pcols - (pcols % vec_size); j += vec_size) {
zero_vec.store(dst_ptr + i * pcols + j);
}
if (j < pcols) {
zero_vec.store(dst_ptr + i * pcols + j, pcols - j);
}
}
}
template <typename scalar_t>
inline void pad_remain_row_col_zero(
scalar_t* value_ptr,
int rows,
int cols,
int prows,
int pcols,
int ldi) {
auto psize = pcols - cols;
if (psize == 0 && prows == rows) {
return;
}
auto vec_size = at::vec::Vectorized<scalar_t>::size();
auto zero = at::vec::Vectorized<scalar_t>(0);
if (psize > 0) {
for (int i = 0; i < rows; i++) {
int j = 0;
for (; j < psize - (psize % vec_size); j += vec_size) {
zero.store(value_ptr + i * ldi + cols + j);
}
if (j < psize) {
zero.store(value_ptr + i * ldi + cols + j, psize - j);
}
}
}
for (int i = rows; i < prows; i++) {
int j = 0;
for (; j < pcols - (pcols % vec_size); j += vec_size) {
zero.store(value_ptr + i * ldi + j);
}
if (j < pcols) {
zero.store(value_ptr + i * ldi + j, pcols - j);
}
}
}
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size, bool with_pack=false>
void cpu_flash_attention(
const Tensor& output,
const Tensor& logsumexp,
@ -278,21 +367,70 @@ void cpu_flash_attention(
int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
int64_t qSlice = (qSize - 1) / qSplitSize + 1;
int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
int64_t num_thread = at::get_num_threads();
const auto dtype = query.scalar_type();
const auto accumulate_dtype = toOpMathType(dtype);
// Whether pack is needed
bool need_pack = false;
// Block size of packing B matrix
int64_t packb_size = 64;
// Use packb_size due to the limitation:
// oneDNN pack only supports output leading dimension being one of (16, 32, 48, 64)
// For instance,
// for q @ k.T [qSplitSize, headSize] * [headSize, kvSplitSize] = [qSplitSize, kvSplitSize],
// we need to split kvSplitSize with packb_size for packing k.T,
// for (q @ k.T) @ v [qSplitSize, kvSplitSize] x [kvSplitSize, headSize] -> [qSplitSize, headSize],
// we need to split headSize with packb_size for packing v
// TODO Simplify the check when oneDNN supports fused pack with transpose and has better performance
if (with_pack) {
need_pack = num_head >= 4 && headSize % packb_size == 0 && kvSize >= packb_size;
if (need_pack) {
float pack_size = batchSize * num_head * kvSize * headSize / 1024;
float gemm_size_per_thread =
(batchSize * num_head * qSlice + num_thread - 1) / num_thread *
qSplitSize * (is_causal ? qSize : kvSize) * headSize / 1024;
float gsize = gemm_size_per_thread / pack_size;
// When the number of gemms is much greater than the number of packs,
// the pack and padding overhead can be overlapped.
if (pack_size < 2688) {
need_pack = gsize >= 36 || (gsize >= 24 && headSize > packb_size);
} else if (pack_size < 16384) {
need_pack = gsize >= (is_causal ? 54 : 52);
} else {
need_pack = gsize >= (is_causal ? 54 : 40);
}
}
}
int64_t rHeadSize = need_pack ? (headSize + packb_size - 1) / packb_size * packb_size : headSize;
int64_t rkvSplitSize = need_pack ? (kvSplitSize + packb_size - 1) / packb_size * packb_size : kvSplitSize;
int64_t rkvTail = need_pack ? (kvTail + packb_size - 1) / packb_size * packb_size : kvTail;
int64_t rkvSize = kv_split_size > kvSize ? rkvTail : rkvSplitSize * kvSlice + rkvTail;
// oneDNN pack does not support odd K for now, so we also need to pad odd K
bool headSize_even = headSize % 2 == 0;
int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize;
int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail;
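// Worked example of the rounding above (illustration only, packb_size = 64):
//   kvSize = 1000, kvSplitSize = 512  ->  kvTail = (1000 - 1) % 512 + 1 = 488
//   rkvSplitSize = div_up(512, 64) * 64 = 512
//   rkvTail      = div_up(488, 64) * 64 = 512
//   ekvTail      = 488 (already even, so no +1 padding needed)
// headSize must already be a multiple of 64 for need_pack to be true,
// so rHeadSize == headSize and eheadSize == headSize in the packed path.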
// allocate per thread temp buf (accumulate type)
int64_t size_per_thread =
/* qk */ qSplitSize * kvSplitSize +
/* qk */ qSplitSize * rkvSplitSize +
/* qk_max */ qSplitSize +
/* qk_sum */ qSplitSize +
/* dst */ qSplitSize * headSize;
/* dst */ qSplitSize * rHeadSize;
at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
at::Tensor buf_reduced = at::empty({num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, query.options());
at::Tensor buf_reduced = at::empty(
{num_thread,
qSplitSize,
is_reduced_type ? ekvSplitSize : 0},
query.options());
// Data ptrs
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
@ -306,16 +444,128 @@ void cpu_flash_attention(
accum_t* buf_data = buf.data_ptr<accum_t>();
scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;
// Buffer to store padding query
scalar_t* query_padding_ptr = nullptr;
std::unique_ptr<scalar_t[]> query_padding_data;
if (!headSize_even && need_pack) {
query_padding_data = std::make_unique<scalar_t[]>(num_thread * qSplitSize * eheadSize);
query_padding_ptr = query_padding_data.get();
}
// Buffer to store Key and Value after transforms
scalar_t* key_reorder_ptr = nullptr;
std::unique_ptr<scalar_t[]> key_reorder_data;
scalar_t* value_reorder_ptr = nullptr;
std::unique_ptr<scalar_t[]> value_reorder_data;
int kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail;
if (need_pack) {
key_reorder_data = std::make_unique<scalar_t[]>(batchSize * num_head * eheadSize * rkvSize);
key_reorder_ptr = key_reorder_data.get();
value_reorder_data = std::make_unique<scalar_t[]>(batchSize * num_head * kv_padding_size * rHeadSize);
value_reorder_ptr = value_reorder_data.get();
}
// Reorder K, V
if (need_pack) {
at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, l = 0, n = 0;
at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
std::unique_ptr<scalar_t[]> transpose_buffer = std::make_unique<scalar_t[]>(eheadSize * packb_size);
scalar_t* transpose_buffer_ptr = transpose_buffer.get();
std::unique_ptr<scalar_t[]> v_copy_buffer = std::make_unique<scalar_t[]>(ekvSplitSize * packb_size);
scalar_t* v_copy_buffer_ptr = v_copy_buffer.get();
for (const auto z : c10::irange(begin, end)) {
(void)z; // Suppress unused variable
n = l * kvSplitSize;
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
int64_t ekvBlockSize = kvBlockSize % 2 == 0 ? kvBlockSize : kvBlockSize + 1;
// Split kvSplitSize with packb_size
// [kvSplitSize, headSize] -> [div_up(kvSplitSize, packb_size), packb_size, headSize]
// Transpose [packb_size, headSize] -> [headSize, packb_size]
// Pack transposed buffer
for (int64_t b = 0; b < kvBlockSize; b += packb_size) {
bool tail = kvBlockSize - b < packb_size;
// TODO Use fused pack with transpose support when oneDNN supports such usage
utils::transpose<uint16_t>(
tail ? kvBlockSize - b : packb_size,
headSize,
/* src_ptr */
reinterpret_cast<const uint16_t*>(
k_data + i * kStrideB + j * kStrideH + n * kStrideN +
b * kStrideN),
/* ld_src */ kStrideN,
/* dst */ reinterpret_cast<uint16_t*>(transpose_buffer_ptr),
/* ld_dst */ packb_size);
// Pad [headSize, x] -> [eheadSize, x]
if (!headSize_even) {
pad_remain_row_col_zero<scalar_t>(
transpose_buffer_ptr,
headSize,
packb_size,
eheadSize,
packb_size,
packb_size);
}
// Pack
cpublas::pack(
/* K */ eheadSize,
/* N */ packb_size,
/* ld_in */ packb_size,
/* ld_out */ packb_size,
/* dt_in */ dtype,
/* dt_out */ dtype,
transpose_buffer_ptr,
key_reorder_ptr + i * num_head * eheadSize * rkvSize +
j * eheadSize * rkvSize + n * eheadSize + b * eheadSize);
}
// Split headSize with packb_size
// [kvSplitSize, headSize] -> [kvSplitSize, div_up(headSize, packb_size), packb_size]
for (int64_t b = 0; b < headSize; b += packb_size) {
// Do copy due to the limitation of input_ld of oneDNN pack:
// Regarding packing [K, N], only input_ld == N is supported
// TODO: remove the copy when pack supports input_ld >= N
copy_value_with_pad<scalar_t>(
v_data + i * vStrideB + j * vStrideH + n * vStrideN + b,
v_copy_buffer_ptr,
kvBlockSize,
(headSize - b < packb_size) ? headSize - b : packb_size,
ekvBlockSize,
packb_size,
vStrideN);
cpublas::pack(
ekvBlockSize,
packb_size,
packb_size,
packb_size,
dtype,
dtype,
v_copy_buffer_ptr,
value_reorder_ptr +
i * num_head * kv_padding_size * rHeadSize +
j * kv_padding_size * rHeadSize + n * rHeadSize +
ekvBlockSize * b);
}
// Move to the next (batch, head, kv block) index
at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
}
});
}
at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, k = 0;
data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
int ompIdx = at::get_thread_num();
accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
accum_t* qk_data = buf_ptr;
accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
accum_t* qk_max_data = qk_data + qSplitSize * rkvSplitSize;
accum_t* qk_sum_data = qk_max_data + qSplitSize;
accum_t* dst_data = qk_sum_data + qSplitSize;
scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize : nullptr;
scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize : nullptr;
scalar_t* query_t_padding_ptr = (!headSize_even && need_pack)
? query_padding_ptr + ompIdx * qSplitSize * eheadSize
: nullptr;
for (const auto z : c10::irange(begin, end)) {
(void)z; // Suppress unused variable
@ -327,10 +577,46 @@ void cpu_flash_attention(
fill_stub(qk_sum_data,
static_cast<accum_t>(0), qBlockSize);
int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
if (!headSize_even && need_pack) {
// Pad query if headSize is not even
// [qBlockSize, headSize] -> [qBlockSize, eheadSize]
copy_value_with_pad<scalar_t>(
q_data + i * qStrideB + j * qStrideH + m * qStrideM,
query_t_padding_ptr,
qBlockSize,
headSize,
qBlockSize,
eheadSize,
qStrideM
);
}
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
int64_t ekvBlockSize = (need_pack && kvBlockSize % 2 != 0) ? kvBlockSize + 1 : kvBlockSize;
int64_t rkvBlockSize = kvBlockSize == kvSplitSize ? rkvSplitSize : rkvTail;
// Calculate scale * q @ k.T
cpublas::gemm(
if (need_pack) {
if constexpr (std::is_same_v<scalar_t, at::Half>) {
for (int64_t b = 0; b < kvBlockSize; b += packb_size) {
cpublas::brgemm(
qBlockSize,
packb_size,
eheadSize,
headSize_even ? qStrideM : eheadSize,
packb_size,
rkvBlockSize,
1.f,
0.f,
!headSize_even
? query_t_padding_ptr
: q_data + i * qStrideB + j * qStrideH + m * qStrideM,
key_reorder_ptr + i * num_head * eheadSize * rkvSize +
j * eheadSize * rkvSize + n * eheadSize + b * eheadSize,
qk_data + b);
}
}
} else {
cpublas::gemm(
TransposeType::Transpose,
TransposeType::NoTranspose,
kvBlockSize,
@ -346,11 +632,12 @@ void cpu_flash_attention(
static_cast<accum_t>(0),
qk_data,
kvBlockSize);
}
// Apply causal mask, fill unused with -inf
if (is_causal && num_keys - n <= kvSplitSize) {
for (const auto row : c10::irange(qBlockSize)) {
int64_t last_col = m + row - n;
accum_t* row_ptr = qk_data + row * kvBlockSize;
accum_t* row_ptr = qk_data + row * rkvBlockSize;
fill_stub(row_ptr + last_col + 1,
-std::numeric_limits<accum_t>::infinity(),
kvBlockSize - last_col - 1);
@ -363,29 +650,29 @@ void cpu_flash_attention(
for (int64_t row = 0; row < qBlockSize; ++row) {
#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
_scale_attn_mask_fusion_kernel(
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
kvBlockSize,
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
scaling_factor,
mStrideN == 0);
#else
if (mStrideN == 0) {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM,
kvBlockSize,
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
scaling_factor);
} else {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
scaling_factor);
}
#endif
@ -398,28 +685,28 @@ void cpu_flash_attention(
// max per row
tmp_max = at::vec::reduce_all<accum_t>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
kvBlockSize);
} else {
// apply scaling factor and max per row in fusion
_mul_reduce_max_fusion_kernel(
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
scaling_factor,
kvBlockSize,
qk_data + row * kvBlockSize,
qk_data + row * rkvBlockSize,
tmp_max);
}
tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
// to avoid `nan = exp2f(-inf - (-inf))`
fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize,
fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
static_cast<scalar_t>(0), kvBlockSize);
} else {
tmp_sum = tmp_max;
// qk <- exp(qk - max) and sum per row
_exp_reduce_sum_fusion_kernel(
qk_data + row * kvBlockSize, kvBlockSize,
conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize,
qk_data + row * rkvBlockSize, kvBlockSize,
conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
tmp_sum);
// exp_tmp <- exp(max[row] - max)
exp_tmp = std::exp(qk_max_data[row] - tmp_max);
@ -431,12 +718,40 @@ void cpu_flash_attention(
if (n > 0) {
vec::map<accum_t>(
[exp_tmp](Vec x) { return x * Vec(exp_tmp); },
dst_data + row * headSize, dst_data + row * headSize, headSize);
dst_data + row * rHeadSize,
dst_data + row * rHeadSize,
headSize);
}
}
if (need_pack && kvBlockSize % 2 != 0) {
// Pad: [qSplitSize,kvSplitSize] -> [qSplitSize,kvSplitSize + 1]
*(qk_reduced_data + row * (1 + kvBlockSize) + kvBlockSize) = scalar_t(0);
}
}
// Calculate Softmax(q @ k.T) @ v
cpublas::gemm(
if (need_pack) {
int64_t psize = n / kvSplitSize * ekvSplitSize;
if constexpr (std::is_same_v<scalar_t, at::Half>) {
for (int64_t b = 0; b < headSize; b += packb_size) {
cpublas::brgemm(
qBlockSize,
packb_size,
ekvBlockSize,
ekvBlockSize,
packb_size,
rHeadSize,
1.0,
n == 0 ? 0.f : 1.f,
qk_reduced_data,
value_reorder_ptr +
i * num_head * kv_padding_size * rHeadSize +
j * kv_padding_size * rHeadSize + psize * rHeadSize +
b * ekvBlockSize,
dst_data + b);
}
}
} else {
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::NoTranspose,
headSize,
@ -451,6 +766,7 @@ void cpu_flash_attention(
n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
dst_data,
headSize);
}
}
// dst <- dst / sum[row]
@ -465,7 +781,7 @@ void cpu_flash_attention(
vec::map<scalar_t>(
[sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
dst_data + row * headSize,
dst_data + row * rHeadSize,
headSize);
}
// Store logsumexp for backward
@ -478,7 +794,9 @@ void cpu_flash_attention(
data_index_step(i, batchSize, j, num_head, k, qSlice);
}
});
if (need_pack) {
cpublas::brgemm_release();
}
}
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
@ -826,6 +1144,13 @@ void cpu_flash_attention_backward(
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Half, mask_t, __VA_ARGS__))
#define FLASH_ATTENTION_KERNEL(FNAME, PACK, TYPE1, TYPE2, SEQ1, SEQ2, ...) \
if (PACK) { \
FNAME<TYPE1, TYPE2, SEQ1, SEQ2, true>(__VA_ARGS__); \
} else { \
FNAME<TYPE1, TYPE2, SEQ1, SEQ2>(__VA_ARGS__); \
}
void flash_attention_kernel_impl(
const Tensor& output,
const Tensor& logsumexp,
@ -838,33 +1163,37 @@ void flash_attention_kernel_impl(
std::optional<double> scale) {
auto q_seq_len = query.size(2);
// When q_seq_len and k_seq_len are long enough,
// cpu_flash_attention with pack has better performance.
bool could_pack = (query.scalar_type() == kHalf && cpublas::need_pack(kHalf));
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention", [&] {
if (!attn_mask.has_value()) {
if (q_seq_len >= 768) {
cpu_flash_attention<scalar_t, scalar_t, 256, 512>(
FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 256, 512,
output, logsumexp, query, key, value,
dropout_p, is_causal, attn_mask, scale);
} else if (q_seq_len >= 192) {
cpu_flash_attention<scalar_t, scalar_t, 64, 512>(
FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 64, 512,
output, logsumexp, query, key, value,
dropout_p, is_causal, attn_mask, scale);
} else {
cpu_flash_attention<scalar_t, scalar_t, 32, 512>(
FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 32, 512,
output, logsumexp, query, key, value,
dropout_p, is_causal, attn_mask, scale);
}
} else {
AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask", [&]() {
if (q_seq_len >= 768) {
cpu_flash_attention<scalar_t, mask_t, 256, 512>(
FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 256, 512,
output, logsumexp, query, key, value,
dropout_p, is_causal, attn_mask, scale);
} else if (q_seq_len >= 192) {
cpu_flash_attention<scalar_t, mask_t, 64, 512>(
FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 64, 512,
output, logsumexp, query, key, value,
dropout_p, is_causal, attn_mask, scale);
} else {
cpu_flash_attention<scalar_t, mask_t, 32, 512>(
FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 32, 512,
output, logsumexp, query, key, value,
dropout_p, is_causal, attn_mask, scale);
}
@ -873,6 +1202,8 @@ void flash_attention_kernel_impl(
});
}
#undef FLASH_ATTENTION_KERNEL
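To make the dispatch above concrete, here is a small compilable sketch of the same macro pattern with a stand-in kernel; fake_kernel and its printf output are illustrative only, while the real call sites pass the query/key/value tensors through __VA_ARGS__.

#include <cstdio>

template <typename scalar_t, typename mask_t, int q_split, int kv_split, bool packed = false>
void fake_kernel(int q_seq_len) {
  std::printf("q_split=%d kv_split=%d packed=%d q_seq_len=%d\n",
              q_split, kv_split, packed ? 1 : 0, q_seq_len);
}

// Same shape as FLASH_ATTENTION_KERNEL above: PACK selects the specialization
// with the extra `true` template argument.
#define FLASH_ATTENTION_KERNEL(FNAME, PACK, TYPE1, TYPE2, SEQ1, SEQ2, ...) \
  if (PACK) {                                                              \
    FNAME<TYPE1, TYPE2, SEQ1, SEQ2, true>(__VA_ARGS__);                    \
  } else {                                                                 \
    FNAME<TYPE1, TYPE2, SEQ1, SEQ2>(__VA_ARGS__);                          \
  }

int main() {
  const bool could_pack = true;  // e.g. fp16 inputs and cpublas::need_pack(kHalf)
  const int q_seq_len = 1024;
  if (q_seq_len >= 768) {
    FLASH_ATTENTION_KERNEL(fake_kernel, could_pack, float, float, 256, 512, q_seq_len);
  } else if (q_seq_len >= 192) {
    FLASH_ATTENTION_KERNEL(fake_kernel, could_pack, float, float, 64, 512, q_seq_len);
  } else {
    FLASH_ATTENTION_KERNEL(fake_kernel, could_pack, float, float, 32, 512, q_seq_len);
  }
  return 0;
}
#undef FLASH_ATTENTION_KERNEL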
void flash_attention_backward_kernel_impl(
const at::Tensor& grad_q,
const at::Tensor& grad_k,

View File

@ -159,6 +159,12 @@ inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
}
template <>
inline void transpose<uint16_t>(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) {
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
fbgemm::transpose_simd<uint16_t>(M, N, src, ld_src, dst, ld_dst);
}
#endif
template <typename index_t, typename F>
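For readers unfamiliar with the (M, N, src, ld_src, dst, ld_dst) convention used by the new uint16_t specialization, a naive reference with the same signature; this is a sketch only, since the specialization above delegates to fbgemm::transpose_simd.

#include <cstdint>
#include <vector>

// src is an M x N row-major matrix with leading dimension ld_src;
// dst receives the N x M transpose with leading dimension ld_dst.
template <typename T>
void transpose_ref(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (int64_t i = 0; i < M; ++i) {
    for (int64_t j = 0; j < N; ++j) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}

int main() {
  std::vector<uint16_t> a = {1, 2, 3, 4, 5, 6};  // 2 x 3, ld_src = 3
  std::vector<uint16_t> b(6);                    // 3 x 2, ld_dst = 2
  transpose_ref<uint16_t>(2, 3, a.data(), 3, b.data(), 2);
  return 0;
}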

View File

@ -964,9 +964,9 @@ ScalingType get_scaling_type(
} // namespace
// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
// Computes matrix multiply + bias while applying scaling to input and output matrices
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
// If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed.
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
// Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrices sizes are divisible by 32
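A scalar-level sketch of the scaling semantics described in the comment above, assuming per-tensor scales; the real kernel operates on whole Float8 matrices through cuBLASLt, and scaled_mm_element is purely illustrative.

// Inputs are quantized (Float8) values; scale_a / scale_b dequantize them,
// and scale_result is only meaningful when the output itself is Float8.
float scaled_mm_element(float a_q, float b_q,
                        float scale_a, float scale_b,
                        bool out_is_fp8, float scale_result) {
  float acc = (a_q * scale_a) * (b_q * scale_b);  // dequantize, then multiply
  return out_is_fp8 ? acc * scale_result : acc;   // 16/32-bit outputs skip scale_result
}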
@ -1068,9 +1068,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
// Some scaled_gemms require an amax to populate lets create one here
Tensor amax = at::empty({0}, mat1.options().dtype(ScalarType::Float));
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
@ -1126,7 +1123,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.c_scale_ptr = scale_result ? scale_result->data_ptr() : nullptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.amax_ptr = amax.data_ptr();
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
@ -1150,11 +1146,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
else
#endif
{
#if defined(USE_ROCM) && ROCM_VERSION >= 60200
// hipBlasLT requires scaleD to be set to something in order to use AMAX
auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA);
auto dummy_scale = at::ones(1, dummy_options);
#endif
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
@ -1172,14 +1163,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
#if defined(USE_ROCM) && ROCM_VERSION >= 60200
scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(),
#else
scale_result ? scale_result->data_ptr() : nullptr,
#endif
args.result_ld,
out_dtype_,
amax.data_ptr(),
use_fast_accum);
}

View File

@ -102,7 +102,7 @@ void col2im_out_cuda_template(
output.resize_({batch_size, n_output_plane, output_height, output_width});
int64_t output_batch_stride = output.stride(0);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
input.scalar_type(), "col2im_out_cuda", [&] {
int64_t height_col = (output_height + 2 * pad_height -
(dilation_height * (kernel_height - 1) + 1)) /

View File

@ -103,7 +103,7 @@ static void im2col_out_cuda_template(
output.resize_({batch_size, n_output_plane, output_length});
// Launch kernel
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
input.scalar_type(), "im2col_out_cuda", [&] {
Tensor input_n;
Tensor output_n;
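The col2im/im2col changes above only extend the dtype dispatch: the AT_DISPATCH_..._AND3 macro instantiates the lambda for one more scalar type (bool). A stripped-down sketch of that mechanism follows, with a stand-in enum and kernel; the real macro expands to a switch over at::ScalarType.

#include <cstdio>

enum class ScalarType { Float, Double, Bool /* Half, BFloat16, ... omitted */ };

template <typename scalar_t>
void im2col_stub() {
  std::printf("kernel instantiated for a %zu-byte element type\n", sizeof(scalar_t));
}

void dispatch(ScalarType t) {
  switch (t) {
    case ScalarType::Float:  im2col_stub<float>();  break;
    case ScalarType::Double: im2col_stub<double>(); break;
    case ScalarType::Bool:   im2col_stub<bool>();   break;  // the newly covered case
  }
}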

View File

@ -1092,7 +1092,11 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
}
constexpr int min_values_per_thread = 16;
#ifndef USE_ROCM
constexpr int max_values_per_thread = 256;
#else
constexpr int max_values_per_thread = 1024;
#endif
if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) {
// Divide the input across warps in a thread-block, if that leaves at least

View File

@ -22,6 +22,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@ -43,6 +44,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@ -86,9 +88,9 @@ using graph_and_tensors = std::tuple<
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>>, // Bias
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
// TODO(eqy): additional options
// std::shared_ptr<fe::graph::Tensor_attributes>, // Bias,
// std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_Q,
// std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_KV,
std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
@ -104,7 +106,8 @@ using graph_and_tensors_backward = std::tuple<
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>>, // Bias,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
std::shared_ptr<fe::graph::Tensor_attributes>, // O,
@ -126,6 +129,8 @@ struct MHAParams {
std::array<int, MAX_MHA_DIM> q_stride;
std::array<int, MAX_MHA_DIM> k_stride;
std::array<int, MAX_MHA_DIM> v_stride;
std::array<int, MAX_MHA_DIM> bias_dim;
std::array<int, MAX_MHA_DIM> bias_stride;
int64_t b;
int64_t h;
int64_t s_q;
@ -135,6 +140,9 @@ struct MHAParams {
double dropout_probability;
bool is_causal;
bool return_softmaxstats;
// might be redundant if we take 0 dim/stride
// as signaling no-bias
bool has_attn_bias;
};
void setMHAParams(
@ -148,6 +156,7 @@ void setMHAParams(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
double dropout_probability,
bool is_causal,
bool return_softmaxstats) {
@ -166,6 +175,7 @@ void setMHAParams(
params.dropout_probability = dropout_probability;
params.is_causal = is_causal;
params.return_softmaxstats = return_softmaxstats;
params.has_attn_bias = attn_bias.has_value();
TORCH_INTERNAL_ASSERT(
q.sizes().size() == MAX_MHA_DIM,
"Q tensor has unexpected number of dims, please report a bug to PyTorch.");
@ -190,6 +200,17 @@ void setMHAParams(
std::copy(k.strides().begin(), k.strides().end(), params.k_stride.begin());
std::copy(v.sizes().begin(), v.sizes().end(), params.v_dim.begin());
std::copy(v.strides().begin(), v.strides().end(), params.v_stride.begin());
// uninit is OK as the struct is memset 0'd
if (params.has_attn_bias) {
std::copy(
attn_bias.value().sizes().begin(),
attn_bias.value().sizes().end(),
params.bias_dim.begin());
std::copy(
attn_bias.value().strides().begin(),
attn_bias.value().strides().end(),
params.bias_stride.begin());
}
}
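Because MHAParams doubles as the cache key, the "uninit is OK as the struct is memset 0'd" note above matters: when no bias is passed, bias_dim and bias_stride must still hold a deterministic byte pattern so that identical bias-free calls hash and compare equal. A reduced sketch of that idea; ParamsSketch and the example shape are hypothetical.

#include <array>
#include <cstring>

struct ParamsSketch {
  std::array<int, 4> bias_dim;
  std::array<int, 4> bias_stride;
  bool has_attn_bias;
};

ParamsSketch make_key(bool has_bias) {
  ParamsSketch p;
  std::memset(&p, 0, sizeof(ParamsSketch));  // zero everything first
  p.has_attn_bias = has_bias;
  if (has_bias) {
    p.bias_dim = {2, 1, 128, 128};           // hypothetical [b, 1, s_q, s_kv] bias
    p.bias_stride = {16384, 16384, 128, 1};
  }
  return p;                                  // byte-wise stable key either way
}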
struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
@ -203,6 +224,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
double dropout_probability,
bool is_causal,
bool return_softmaxstats) {
@ -217,6 +239,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
return_softmaxstats);
@ -285,6 +308,7 @@ auto build_graph_and_tensors(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@ -301,36 +325,6 @@ auto build_graph_and_tensors(
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(std::vector<int64_t>(
q.sizes().data(), q.sizes().data() + q.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
q.sizes(),
std::vector<int64_t>(
q.strides().data(),
q.strides().data() + q.strides().size()))));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(std::vector<int64_t>(
k.sizes().data(), k.sizes().data() + k.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
k.sizes(),
std::vector<int64_t>(
k.strides().data(),
k.strides().data() + k.strides().size()))));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(std::vector<int64_t>(
v.sizes().data(), v.sizes().data() + v.sizes().size()))
.set_stride(fixSizeOneDimStrideSDPA(
v.sizes(),
std::vector<int64_t>(
v.strides().data(),
v.strides().data() + v.strides().size()))));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
@ -338,11 +332,6 @@ auto build_graph_and_tensors(
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
// TODO(eqy): support bias in the future in a follow-up PR
// auto bias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("bias")
// .set_dim({b, 1, s_q, s_kv})
// .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1}));
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
@ -360,11 +349,30 @@ auto build_graph_and_tensors(
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
.set_dropout(dropout_probability, seed, offset);
// Optional bias in flash attention is only supported from 8.9.3 onwards
if (cudnnGetVersion() >= 8904) {
// scaled_dot_product_flash_attention_options.set_alibi_mask(true);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(q.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec())));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(k.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec())));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(v.sizes().vec())
.set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec())));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
}
auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_q")
.set_dim({b, 1, 1, 1})
@ -376,20 +384,9 @@ auto build_graph_and_tensors(
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
// if (cudnnGetVersion() >= 8903) {
// scaled_dot_product_flash_attention_options.set_bias(bias)
// .set_padding_mask(true)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }
auto [O, Stats] =
mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
O->set_output(true)
.set_dim(std::vector<int64_t>(
o.sizes().data(), o.sizes().data() + o.sizes().size()))
.set_stride(std::vector<int64_t>(
o.strides().data(), o.strides().data() + o.strides().size()));
O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec());
if (Stats) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
@ -407,6 +404,7 @@ auto build_graph_and_tensors(
std::move(Q),
std::move(K),
std::move(V),
std::move(bias),
std::move(attn_scale),
std::move(seed),
std::move(offset),
@ -427,6 +425,7 @@ auto build_graph_and_tensors_backward(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@ -447,24 +446,6 @@ auto build_graph_and_tensors_backward(
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(std::vector<int64_t>(q.sizes().begin(), q.sizes().end()))
.set_stride(
std::vector<int64_t>(q.strides().begin(), q.strides().end())));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(std::vector<int64_t>(k.sizes().begin(), k.sizes().end()))
.set_stride(
std::vector<int64_t>(k.strides().begin(), k.strides().end())));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(std::vector<int64_t>(v.sizes().begin(), v.sizes().end()))
.set_stride(
std::vector<int64_t>(v.strides().begin(), v.strides().end())));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
@ -472,6 +453,31 @@ auto build_graph_and_tensors_backward(
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("CUDNN_SDPA_BACKWARD")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(q.sizes().vec())
.set_stride(q.strides().vec()));
auto K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(k.sizes().vec())
.set_stride(k.strides().vec()));
auto V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(v.sizes().vec())
.set_stride(v.strides().vec()));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
sdpa_backward_options.set_bias(bias.value());
}
auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
@ -482,47 +488,27 @@ auto build_graph_and_tensors_backward(
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto O = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("O")
.set_dim(std::vector<int64_t>(o.sizes().begin(), o.sizes().end()))
.set_stride(
std::vector<int64_t>(o.strides().begin(), o.strides().end())));
auto STATS = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Stats")
.set_dim(std::vector<int64_t>(
softmaxstats.sizes().begin(), softmaxstats.sizes().end()))
.set_stride(std::vector<int64_t>(
softmaxstats.strides().begin(), softmaxstats.strides().end()))
.set_data_type(fe::DataType_t::FLOAT));
auto DO = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("DO")
.set_dim(std::vector<int64_t>(dO.sizes().begin(), dO.sizes().end()))
.set_stride(
std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("CUDNN_SDPA_BACKWARD")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
auto O = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim(o.sizes().vec())
.set_stride(o.strides().vec()));
auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Stats")
.set_dim(softmaxstats.sizes().vec())
.set_stride(softmaxstats.strides().vec())
.set_data_type(fe::DataType_t::FLOAT));
auto DO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("DO")
.set_dim(dO.sizes().vec())
.set_stride(dO.strides().vec()));
if (dropout_probability != 0.0f) {
sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
}
auto [DQ, DK, DV] =
mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
DQ->set_output(true)
.set_dim(std::vector<int64_t>(dQ.sizes().begin(), dQ.sizes().end()))
.set_stride(
std::vector<int64_t>(dQ.strides().begin(), dQ.strides().end()));
DK->set_output(true)
.set_dim(std::vector<int64_t>(dK.sizes().begin(), dK.sizes().end()))
.set_stride(
std::vector<int64_t>(dK.strides().begin(), dK.strides().end()));
DV->set_output(true)
.set_dim(std::vector<int64_t>(dV.sizes().begin(), dV.sizes().end()))
.set_stride(
std::vector<int64_t>(dV.strides().begin(), dV.strides().end()));
DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec());
DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec());
DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec());
AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
AT_CUDNN_FRONTEND_CHECK(
@ -534,6 +520,7 @@ auto build_graph_and_tensors_backward(
std::move(Q),
std::move(K),
std::move(V),
std::move(bias),
std::move(attn_scale),
std::move(Seed),
std::move(Offset),
@ -559,6 +546,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@ -573,6 +561,11 @@ void run_cudnn_SDP_fprop(
softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat));
}
// do nothing if we got 0-element tensors
if (!q.numel() || !k.numel() || !v.numel()) {
return;
}
auto key = MHACacheKeyWrapper(
b,
h,
@ -583,6 +576,7 @@ void run_cudnn_SDP_fprop(
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
return_softmaxstats);
@ -605,13 +599,14 @@ void run_cudnn_SDP_fprop(
q,
k,
v,
attn_bias,
softmaxstats,
o,
dropoutseed,
dropoutoffset,
handle);
}
auto [mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats] =
auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] =
graph_and_tensors_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
variant_pack = {
@ -619,13 +614,15 @@ void run_cudnn_SDP_fprop(
{K, k.data_ptr()},
{V, v.data_ptr()},
{attn_scale, &scaling_factor},
//{bias, bias.data_ptr()},
{seed, dropoutseed.data_ptr()},
{offset, dropoutoffset.data_ptr()},
{O, o.data_ptr()}};
if (return_softmaxstats) {
variant_pack[Stats] = softmaxstats.data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[bias.value()] = attn_bias.value().data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
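The bias plumbing above follows one pattern end to end: the graph node exists only when attn_bias is present, it travels through the cached tuple as a std::optional, and its device pointer is bound into the variant pack just before execution. A self-contained sketch of that shape; GraphTensor and run_sketch are stand-ins, not cudnn_frontend types.

#include <memory>
#include <optional>
#include <unordered_map>

struct GraphTensor {};  // stand-in for fe::graph::Tensor_attributes
using TensorHandle = std::shared_ptr<GraphTensor>;

void run_sketch(const std::optional<void*>& attn_bias_data) {
  // Build time: only materialize a graph node when a bias was supplied.
  std::optional<TensorHandle> bias;
  if (attn_bias_data.has_value()) {
    bias = std::make_shared<GraphTensor>();
  }
  // Execute time: bind the device pointer only for the node that exists.
  std::unordered_map<TensorHandle, void*> variant_pack;
  if (attn_bias_data.has_value()) {
    variant_pack[bias.value()] = attn_bias_data.value();
  }
}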
@ -647,6 +644,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
@ -655,6 +653,12 @@ void run_cudnn_SDP_bprop(
Tensor& dV,
const Tensor& dropoutseed,
const Tensor& dropoutoffset) {
// do nothing if we got 0-element tensors
if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() ||
!softmaxstats.numel()) {
return;
}
Tensor dO_ = dO;
if (!dO.strides()[dO.strides().size() - 1]) {
TORCH_WARN(
@ -694,6 +698,7 @@ void run_cudnn_SDP_bprop(
q,
k,
v,
attn_bias,
dropout_probability,
is_causal,
true);
@ -715,6 +720,7 @@ void run_cudnn_SDP_bprop(
q,
k,
v,
attn_bias,
o,
dO_,
softmaxstats,
@ -726,8 +732,20 @@ void run_cudnn_SDP_bprop(
handle);
}
auto
[mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] =
graph_and_tensors_backward_values;
[mha_graph,
Q,
K,
V,
bias,
attn_scale,
Seed,
Offset,
O,
Do,
Stats,
Dq,
Dk,
Dv] = graph_and_tensors_backward_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
variant_pack = {// inputs
{Q, q.data_ptr()},
@ -746,6 +764,9 @@ void run_cudnn_SDP_bprop(
variant_pack[Seed] = dropoutseed.data_ptr();
variant_pack[Offset] = dropoutoffset.data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[bias.value()] = attn_bias.value().data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);

View File

@ -18,6 +18,7 @@ void run_cudnn_SDP_fprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
Tensor& softmaxstats,
Tensor& o,
Tensor& dropoutseed,
@ -36,6 +37,7 @@ void run_cudnn_SDP_bprop(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,

Some files were not shown because too many files have changed in this diff