Compare commits

...

40 Commits

Author SHA1 Message Date
c831fcb0b1 Update on "WIP: testing fix for #165177"
cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-10-17 08:11:22 -07:00
20111eac3e WIP: testing fix for #165177
[ghstack-poisoned]
2025-10-16 20:41:52 -07:00
fcbde24c1c [ONNX] Remove common imports from torchlib (#165156)
The Rank and IsScalar functions are no longer used in the torchlib. Requires onnxscript v0.5.4

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165156
Approved by: https://github.com/Skylion007, https://github.com/cyyever
2025-10-17 03:25:34 +00:00
861cdb887b use statically_known_leq & *=2 instead of bound_sympy in persistent rblock (#165657)
While these should be equivalent, we've found instances where they are not, and an error was caused. Update until we figure out the underlying issue.
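
As a rough illustration of the two strategies the title refers to, here is a minimal, self-contained sketch; the helper names below are hypothetical stand-ins, not the actual Inductor `bound_sympy` / `statically_known_leq` call sites:

```
# Illustrative only -- not the Inductor implementation.
def next_power_of_2(n: int) -> int:
    return 1 << (max(n, 1) - 1).bit_length()

def rblock_from_upper_bound(rnumel_upper_bound: int) -> int:
    # bound_sympy-style: jump straight to the next power of two of an upper bound
    return next_power_of_2(rnumel_upper_bound)

def rblock_by_doubling(rnumel: int, statically_known_leq) -> int:
    # statically_known_leq-style: start at 1 and double (*= 2) until the check
    # can prove that rnumel fits in the current block
    rblock = 1
    while not statically_known_leq(rnumel, rblock):
        rblock *= 2
    return rblock

# With plain integers the two agree; the PR notes they diverged in practice
# for symbolic shapes, hence the switch.
assert rblock_from_upper_bound(48) == 64
assert rblock_by_doubling(48, lambda a, b: a <= b) == 64
```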

Differential Revision: [D84835898](https://our.internmc.facebook.com/intern/diff/D84835898)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165657
Approved by: https://github.com/bobrenjc93
2025-10-17 02:48:03 +00:00
3154482072 [CUDA][cuBLAS] Only xFail addmm with reduced precision reductions on non-RTX skus (#165379)
RTX Blackwells don't behave quite like their datacenter counterparts here

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165379
Approved by: https://github.com/Skylion007
2025-10-17 02:45:07 +00:00
9fccbdd4f0 Fix incorrect function signature in template (#165567)
Summary:
In https://github.com/pytorch/pytorch/pull/148305 we refactored the grid
argument out, but it's not reflected in our template.

Test Plan:
Included in commit.
python test/inductor/test_aot_inductor.py AOTInductorTestABICompatibleGpu.test_cond_symint_input_disable_one_pass_cuda

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165567
Approved by: https://github.com/desertfire
2025-10-17 02:40:56 +00:00
7dabfb07cb [torchfuzz] add support for --stop-at-first-failure flag (#165529)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165529
Approved by: https://github.com/pianpwk
ghstack dependencies: #164749
2025-10-17 02:18:07 +00:00
d0add0be43 [torchfuzz] check in some more ignore regexes (#164749)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164749
Approved by: https://github.com/pianpwk
2025-10-17 02:18:07 +00:00
11e2084308 Revert "[Mem Snapshot] Add Metadata Field (#165490)"
This reverts commit 5b3ea758951558e7d9f681ae784acb57eaa07910.

Reverted https://github.com/pytorch/pytorch/pull/165490 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165490#issuecomment-3413491091))
2025-10-17 02:01:53 +00:00
9726553653 [BE][Ez]: Use sys.executable instead of hardcoded Python (#165679)
Handles an edge case to ensure the proper interpreter is called. Inspired by #165633
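
A minimal sketch of the pattern this kind of fix applies (the subprocess call below is illustrative, not the code touched by this PR):

```
import subprocess
import sys

# Fragile: depends on a "python" binary being on PATH and pointing at the right runtime.
# subprocess.check_call(["python", "tools/some_helper.py"])

# Robust: reuse the interpreter that is running the current process.
subprocess.check_call([sys.executable, "-c", "print('launched with the same interpreter')"])
```
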
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165679
Approved by: https://github.com/FindHao
2025-10-17 01:07:40 +00:00
d82527b32a [Windows] Add AOTI cross-compilation CI (#165573)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165573
Approved by: https://github.com/malfet
ghstack dependencies: #165560
2025-10-17 01:05:35 +00:00
5d9b024276 Add mingw to docker (#165560)
Add mingw to `pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11` docker image to support AOTI cross-compilation

This PR will trigger a Docker container rebuild and upgrade the Python version from 3.13.7 to 3.13.8. It relies on https://github.com/pytorch/pytorch/pull/165667
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165560
Approved by: https://github.com/malfet
2025-10-17 00:47:01 +00:00
5b2afe4c5d Turn some const variables into constexpr in C++ code (#165401)
This PR checks the C++ code and turns some const variables into constexpr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165401
Approved by: https://github.com/Skylion007
2025-10-17 00:40:11 +00:00
b2953f5643 [9/N] Apply ruff UP035 rule (#165515)
This is a follow-up to #165214, continuing to apply the ruff UP035 rule to the code base.
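
For reference, a representative before/after of the kind of import UP035 flags (illustrative snippet, not one of the files touched here):

```
# Before: deprecated typing imports flagged by ruff UP035
# from typing import Callable, Sequence

# After: import from collections.abc and use builtin generics
from collections.abc import Callable, Sequence

def apply_all(fns: Sequence[Callable[[int], int]], x: int) -> list[int]:
    return [fn(x) for fn in fns]

print(apply_all([lambda v: v + 1, lambda v: v * 2], 3))  # [4, 6]
```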

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165515
Approved by: https://github.com/Lucaskabela
2025-10-17 00:09:51 +00:00
470e2f61c3 Revert "[Fix] Use sys.executable instead of hardcoded python (#165633)"
This reverts commit 37f3ba274a8ccebc6b3409f52cf068a8b23617d4.

Reverted https://github.com/pytorch/pytorch/pull/165633 on behalf of https://github.com/malfet due to Looks like it broke test_collect_callgrind in slow workflows, see e0fe37fa68/1 ([comment](https://github.com/pytorch/pytorch/pull/165633#issuecomment-3413290813))
2025-10-17 00:06:40 +00:00
e0fe37fa68 [MPS] Move torch.cat impl to Metal (#165373)
After this change, all of the cases tested in [this performance measurement script](10de64c5ac/cat/perf0.py) run in either roughly the same time or less.

Before:

```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000857 ms, 0.016098 ms, 0.05, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000858 ms, 0.014861 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000806 ms, 0.015145 ms, 0.05, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000829 ms, 0.015355 ms, 0.05, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000591 ms, 0.000582 ms, 1.02, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001076 ms, 0.022387 ms, 0.05, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000708 ms, 0.022300 ms, 0.03, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000640 ms, 0.014367 ms, 0.04, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000777 ms, 0.027506 ms, 0.03, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003383 ms, 0.269277 ms, 0.01, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.526138 ms, 0.650852 ms, 0.81, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.444091 ms, 0.628630 ms, 0.71, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 2.011870 ms, 0.989525 ms, 2.03, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
13: 3.100653 ms, 0.948178 ms, 3.27, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
14: 3.112174 ms, 0.954174 ms, 3.26, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```

After:

```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000790 ms, 0.013111 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000800 ms, 0.014419 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000748 ms, 0.010019 ms, 0.07, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000767 ms, 0.010063 ms, 0.08, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000591 ms, 0.000591 ms, 1.00, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001220 ms, 0.009763 ms, 0.12, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000739 ms, 0.006203 ms, 0.12, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000647 ms, 0.009905 ms, 0.07, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000753 ms, 0.007818 ms, 0.10, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003823 ms, 0.192723 ms, 0.02, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.576564 ms, 0.733920 ms, 0.79, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.462957 ms, 0.692799 ms, 0.67, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 2.017181 ms, 0.968345 ms, 2.08, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
13: 3.203508 ms, 0.986382 ms, 3.25, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
14: 3.181249 ms, 1.007773 ms, 3.16, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```
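
For anyone wanting to spot-check numbers like the above, a minimal timing sketch (assumes a machine with an MPS device; this is not the measurement script linked above):

```
import time
import torch

def time_cat(device: str, iters: int = 50) -> float:
    xs = [torch.rand(3, 1_000_000, 2, device=device) for _ in range(2)]
    torch.cat(xs, dim=1)  # warm up
    if device == "mps":
        torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        torch.cat(xs, dim=1)
    if device == "mps":
        torch.mps.synchronize()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

print(f"cpu: {time_cat('cpu'):.3f} ms")
if torch.backends.mps.is_available():
    print(f"mps: {time_cat('mps'):.3f} ms")
```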

Fixes #165350
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165373
Approved by: https://github.com/kulinseth, https://github.com/malfet
2025-10-17 00:03:04 +00:00
d2c82bafb7 Revert "158232 Fix autocast cache incorrectly retaining no_grad state (#165068)"
This reverts commit 5daef30b26b794d237fbbc399c1d47ec0380200a.

Reverted https://github.com/pytorch/pytorch/pull/165068 on behalf of https://github.com/jeffdaily due to This broke ROCm CI. test/test_transformers.py::TestTransformersCUDA::test_transformerencoder_fastpath_use_torchscript_False_enable_nested_tensor_True_use_autocast_True_d_model_256_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/18572589089/job/52952074008) [HUD commit link](5daef30b26) ([comment](https://github.com/pytorch/pytorch/pull/165068#issuecomment-3413184445))
2025-10-16 23:08:27 +00:00
98a488c9aa Start recording inductor provenance (#162669)
Summary:
This stores information on where fx graphs come from, which makes it
significantly easier to debug.

One outstanding question:

1) I only stored the kernel stack traces; do we also want the node mappings?

Test Plan:
I wrote an explicit logging test which makes a module, fx traces it, compiles it, and makes sure the logging information shows up.

```
clr@devvm17763 ~/fbsource/fbcode/caffe2/test/dynamo
 % buck2 test @//mode/opt fbcode//caffe2/test/dynamo:test_dynamo -- test_utils

File changed: fbsource//xplat/caffe2/test/dynamo/test_utils.py
File changed: fbcode//caffe2/test/dynamo/test_utils.py
Buck UI: https://www.internalfb.com/buck2/528dea32-2416-4a62-a1ec-39f3c0efdd2e
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13229324015574003
Network: Up: 0B  Down: 0B
Executing actions. Remaining     0/2
Command: test.
Time elapsed: 17.3s
Tests finished: Pass 16. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D82037582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162669
Approved by: https://github.com/yushangdi
2025-10-16 23:05:31 +00:00
5b3ea75895 [Mem Snapshot] Add Metadata Field (#165490)
Summary:
The implementation adds the ability to:

- Set custom metadata strings that will be attached to all subsequent allocations
- Clear or change the metadata at any point
- View the metadata in memory snapshots via _dump_snapshot()
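
A hedged usage sketch: `_record_memory_history` and `_dump_snapshot` are existing private APIs, while the metadata setter below is a hypothetical placeholder for whatever entry point this PR exposes:

```
import torch

torch.cuda.memory._record_memory_history()

# Hypothetical placeholder for the metadata API added by this PR; the real
# entry point may have a different name or location.
# torch.cuda.memory._set_snapshot_metadata("phase=forward")

x = torch.randn(1024, 1024, device="cuda")

# The snapshot can be inspected with the pytorch.org/memory_viz tool.
torch.cuda.memory._dump_snapshot("snapshot.pickle")
```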

Test Plan: Added a test in test_cuda.py and checked manually in the snapshot to see that the metadata was added.

Differential Revision: D84654933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165490
Approved by: https://github.com/yushangdi
2025-10-16 22:54:27 +00:00
556fc09a9f [DebugMode][1/N] refactor logs into _DebugCalls (#165376)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165376
Approved by: https://github.com/SherlockNoMad
2025-10-16 22:43:52 +00:00
ce109b3f79 Add torch.backends.mkldnn.is_acl_available() method (#165678)
That tells whether or not PyTorch was compiled with Arm Compute Library
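
Usage is a simple boolean query (sketch; requires a build new enough to include this method):

```
import torch

# True only when this PyTorch build was compiled against Arm Compute Library.
print(torch.backends.mkldnn.is_acl_available())
```
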
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165678
Approved by: https://github.com/Skylion007, https://github.com/atalman, https://github.com/albanD
ghstack dependencies: #165583, #165584, #165676
2025-10-16 22:34:21 +00:00
4d833f859b [BE] [CI] Fix aarch64 arch checks (#165676)
Instead of relying on the `TEST_CONFIG` environment variable to contain `aarch64`, which is prone to errors, use the output of `$(uname -m)`, which is equal to `aarch64` on Linux ARM systems.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165676
Approved by: https://github.com/huydhn, https://github.com/atalman
ghstack dependencies: #165583, #165584
2025-10-16 22:19:53 +00:00
d7e275d4b4 [CI][CUDA] Add periodic b200 distributed job (#159323)
1. Run the distributed job on a B200 runner, periodically.
2. Discovered a generic distributed test issue: certain unit tests hard-code ranks, calling for a require_exact_world_size(world_size) API instead of require_world_size(world_size).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159323
Approved by: https://github.com/eqy

Co-authored-by: Aidyn-A <aidyn.b.aitzhan@gmail.com>
2025-10-16 21:54:04 +00:00
d5db3aee0d [CI] Use 1-GPU runners for rocm-mi355.yml (#165658)
Should only need 1-GPU runners for rocm-mi355.yml since it runs the `default` test config, which only needs 1 GPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165658
Approved by: https://github.com/jeffdaily
2025-10-16 21:53:22 +00:00
5641de7b6b Add suppressions for _inductor/codegen (#165659)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165659
Approved by: https://github.com/oulgen
2025-10-16 21:37:37 +00:00
cbc08c8993 Add NEON acceleration for Vectorized<int[8|16|32|64> (#165273)
Summary:
Adding NEON specializations of Vectorized<T> for int8, int16, int32 and int64.

Correctness has been checked using test_ops.py and the comprehensive torch tests.

operator_benchmark_test.py has been enhanced by adding cases of bitwise operations, boolean ops and integer ops.
The benchmark, which uses the PyTorch API, shows significant enhancements in a wide variety of operations:

Before:

bitwise xor: 779.882us
boolean any: 636.209us
boolean all: 538.621us
integer mul: 304.457us
integer asr: 447.997us

After:

bitwise xor: 680.221us ---> 15% higher throughput
boolean any: 391.468us ---> 63% higher throughput
boolean all: 390.189us ---> 38% higher throughput
integer mul: 193.532us ---> 57% higher throughput
integer asr: 179.929us ---> 149% higher throughput
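
A lightweight way to sanity-check this class of integer/boolean ops outside the internal harness, using torch.utils.benchmark (illustrative tensor sizes; not the operator_benchmark_test.py runs quoted above):

```
import torch
from torch.utils import benchmark

a = torch.randint(0, 100, (1_000_000,), dtype=torch.int32)
b = torch.randint(1, 100, (1_000_000,), dtype=torch.int32)
mask = a > 50

for label, stmt in [
    ("bitwise xor", "a ^ b"),
    ("boolean any", "mask.any()"),
    ("integer mul", "a * b"),
    ("integer asr", "a >> 2"),
]:
    timer = benchmark.Timer(stmt=stmt, globals={"a": a, "b": b, "mask": mask})
    print(label, timer.blocked_autorange())
```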

Test Plan:
Correctness:

buck2 test @mode/opt //caffe2/test:test_ops
buck2 test @mode/opt //caffe2/test:torch
buck2 test @mode/opt //caffe2/test/distributed/launcher/fb:fb_run_test

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D84424638

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165273
Approved by: https://github.com/malfet
2025-10-16 21:35:13 +00:00
1a54d3333d [easy] Fix graph_capture in aot_joint_with_descriptors test (#165660)
When `with_export=True`, `aot_export_joint_with_descriptors` should take the graph produced by `_dynamo_graph_capture_for_export`.

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_annotate_simple
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_annotate_flex_attention
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165660
Approved by: https://github.com/yushangdi
2025-10-16 21:10:11 +00:00
4c1c341fa0 FakeTensorMode shouldn't cache syms when tracing (#164718)
Improve FakeTensor cache to handle SymNode and tracing properly.

For now, when we're proxy tracing, just don't bother caching operations that contain SymNodes in the output. The problem is that the proxy tracer relies on SymNode identity and our cache doesn't preserve that. It can be fixed (I left some notes in _validate_symbolic_output_for_caching() on how) but it's not worth it for now.

If we aren't proxy tracing then caching is fine.

Thus these changes:

1. Our cache key needs to include whether we were actively tracing or not - this way, if we create a cache entry when we weren't tracing and then try to use it when we ARE tracing, it gets rerun (see the sketch after this list).

2. If there's a SymNode in the output then bypass tracing.

3. Some general cleanup of the output validation - we were unnecessarily doing it as a two-step process when it could just be a single step (it's still two parts internally but only a single outer try/except).
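
To make point 1 concrete, a small generic sketch (not the FakeTensor cache itself) of folding the tracing state into the cache key, so entries created outside tracing are never reused while tracing:

```
from dataclasses import dataclass, field

@dataclass(frozen=True)
class CacheKey:
    op: str
    arg_shapes: tuple
    proxy_tracing_active: bool  # distinguishes entries created while tracing

@dataclass
class OpCache:
    entries: dict = field(default_factory=dict)

    def lookup(self, key: CacheKey):
        return self.entries.get(key)

    def store(self, key: CacheKey, value) -> None:
        self.entries[key] = value

cache = OpCache()
cache.store(CacheKey("add", ((2, 3), (2, 3)), proxy_tracing_active=False), "cached-result")

# Same op while tracing: different key, so the op is re-executed (and re-traced).
assert cache.lookup(CacheKey("add", ((2, 3), (2, 3)), proxy_tracing_active=True)) is None
```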

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #165266, #164717
2025-10-16 20:57:07 +00:00
5f21cc786a Teach ProxyTorchDispatchMode how to decompose sympy.Expr into known inputs (#164717)
In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.

The problem occurs because in sharding_prop we use FakeTensors to compute an operation's size (so we don't have to use the full "real" data). We turn off proxy tracing while we're doing that because we don't want the FakeTensor ops to end up in the graph. We then use that size in later operations.

Normally this is no problem - but when those sizes are dynamic shapes we have a problem: the proxy tracer wants to track the provenance of all shape operations (`s1*s2`), but since tracing is disabled it doesn't see the operation, and when we then use the resulting shape later on, the proxy tracer gets confused (because the SymNode appeared out of nowhere).

At first we considered never disabling shape tracing - but that caused a slew of other downstream problems (lots of code actually needs shape tracing to be disabled), so instead we add a "sym tracing override": when we disable proxy tracing, we surgically leave shape tracing enabled.

After this change the dtensor embedding is "fixed" but then runs afoul of a FakeTensor cache bug - which is fixed in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: #165266
2025-10-16 20:57:06 +00:00
e86942f422 minor proxy_tensor reorg (#165266)
Moving some code around in proxy_tensor in preparation for the next PR. There were
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165266
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-16 20:57:06 +00:00
2cd5fd1588 Enable local tensor mode on DTensor view ops test (#165596)
While enabling this test, we discovered a lack of support for sub meshes. Added limited support
for sub meshes by properly computing rank coordinates for a given sub mesh. The implementation
follows a similar approach to collectives. We infer all sub meshes for the given dimensions and
compute each rank's coordinates with respect to its sub mesh.
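
An illustrative (non-DTensor) sketch of the idea: locate a rank in the global mesh and read off its coordinates restricted to the dimensions that form its sub mesh:

```
import torch

# A 2 x 2 x 2 global mesh of ranks 0..7.
global_mesh = torch.arange(8).reshape(2, 2, 2)

def submesh_coords(rank: int, submesh_dims: tuple) -> tuple:
    # Coordinates of `rank` in the global mesh...
    global_coords = (global_mesh == rank).nonzero()[0].tolist()
    # ...restricted to the dimensions that form its sub mesh.
    return tuple(global_coords[d] for d in submesh_dims)

# Rank 5 sits at global coordinates (1, 0, 1); within the sub mesh spanning
# dims (1, 2) its coordinates are (0, 1).
assert submesh_coords(5, (1, 2)) == (0, 1)
```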

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165596
Approved by: https://github.com/ezyang
2025-10-16 20:52:06 +00:00
7d0f872cb3 Use union syntax in torch/_inductor runtime and fx_passes (#165652)
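
The change is mechanical; a representative before/after (illustrative function, not one of the touched files):

```
# Before: typing-module spellings
# from typing import Optional, Union
# def lookup(key: str, default: Optional[int] = None) -> Union[int, None]: ...

# After: PEP 604 union syntax
def lookup(key: str, default: int | None = None) -> int | None:
    return default

print(lookup("x", 3))  # 3
```
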
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165652
Approved by: https://github.com/aorenste
2025-10-16 20:51:59 +00:00
fb06e49ce8 Revert "[inductor] print 0.0 as 0 for triton (#164291)"
This reverts commit 99b32a6750bfd0cfe2bc84a47823e1da34802b7b.

Reverted https://github.com/pytorch/pytorch/pull/164291 on behalf of https://github.com/malfet due to Broke slow job, see aba8c43594/1  ([comment](https://github.com/pytorch/pytorch/pull/164291#issuecomment-3412768915))
2025-10-16 20:44:29 +00:00
27a98e6ae9 Revert "[DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)"
This reverts commit d61a9b88cf3be04a29c5a7d6e9622ae5e8d51de3.

Reverted https://github.com/pytorch/pytorch/pull/165554 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see aba8c43594/1 ([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
2025-10-16 20:41:37 +00:00
b10f463b1a Revert "[DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)"
This reverts commit 99097b6d89c927c15180ff4683c38be01f9955f6.

Reverted https://github.com/pytorch/pytorch/pull/165555 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see aba8c43594/1 ([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
2025-10-16 20:41:37 +00:00
431c13cf61 Revert "[DeviceMesh] Simplify unflatten method (#165556)"
This reverts commit 86fd4fc23e697e275d37c36e3cbe521f156434fd.

Reverted https://github.com/pytorch/pytorch/pull/165556 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see aba8c43594/1 ([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
2025-10-16 20:41:37 +00:00
aead9270f5 12/n : Remove fbandroid_compiler_flags (#165558)
Summary:
Currently `get_c2_fbandroid_xplat_compiler_flags()` reads the `caffe2.strip_glog` buckconfig, which we want to get rid of.
This diff removes the `fbandroid_compiler_flags` arg and merges it into `compiler_flags` using a nested select and the select version of the method.

The goal is to get rid of all the usages of `get_c2_fbandroid_xplat_compiler_flags()` so that we can get rid of the `caffe2.strip_glog` buckconfig

Test Plan: CI

Differential Revision: D84626885

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165558
Approved by: https://github.com/malfet
2025-10-16 20:41:24 +00:00
9bf5b38c14 [Inductor][Triton][FP8] Refactor scaled_mm template to accept scaling mode (#164318)
Summary: Refactor `scaled_mm` Inductor template to support template choice based on scaling mode. This modification sets up the infrastructure for adding new templates based on new scaling modes, such as deepseek-style scaling (a follow-up diff), as new scaling modes (deepseek, block, group) scale before the accumulation (as opposed to per-tensor and per-row scaling, which apply scaling after accumulation). This modification also further enables Inductor to infer a scaling type based on the shape of the scaling tensors, which makes existing infrastructure more extensible to new scaling modes.
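
A hedged sketch of the shape-based inference the summary describes (simplified, hypothetical helper; not the Inductor template code):

```
import torch

def infer_scaling_mode(mat: torch.Tensor, scale: torch.Tensor) -> str:
    m = mat.shape[0]
    if scale.numel() == 1:
        return "tensorwise"            # one scale for the whole tensor
    if scale.shape in ((m, 1), (m,)):
        return "rowwise"               # one scale per row
    return "blockwise/grouped"         # finer-grained modes scale before accumulation

a = torch.randn(256, 512)
print(infer_scaling_mode(a, torch.tensor(1.0)))   # tensorwise
print(infer_scaling_mode(a, torch.ones(256, 1)))  # rowwise
```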

Test Plan:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling_rowwise --atol=20 --rtol=2 2>&1 | tee ~/personal/random.log
```

Differential Revision: D83591083

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164318
Approved by: https://github.com/drisspg, https://github.com/slayton58
2025-10-16 20:40:45 +00:00
aba8c43594 Register var for MTIA (#165382)
Summary: Registers variance kernel

Reviewed By: srsuryadev

Differential Revision: D84546250

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165382
Approved by: https://github.com/malfet
2025-10-16 20:35:15 +00:00
37f3ba274a [Fix] Use sys.executable instead of hardcoded python (#165633)
Replace the hardcoded "python" string with sys.executable to ensure the correct Python interpreter is used. This fixes failures on systems with multiple Python runtimes or where "python" is not in PATH.

Similar to pytorch/pytorch#155918

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165633
Approved by: https://github.com/Skylion007
2025-10-16 20:26:10 +00:00
158 changed files with 3094 additions and 1044 deletions

View File

@@ -113,6 +113,7 @@ case "$tag" in
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INSTALL_MINGW=yes
;;
pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
CUDA_VERSION=13.0.0
@@ -361,6 +362,7 @@ docker build \
--build-arg "OPENBLAS=${OPENBLAS:-}" \
--build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
--build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
--build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \
-f $(dirname ${DOCKERFILE})/Dockerfile \
-t "$tmp_tag" \
"$@" \

View File

@@ -0,0 +1,10 @@
#!/bin/bash
set -ex
# Install MinGW-w64 for Windows cross-compilation
apt-get update
apt-get install -y g++-mingw-w64-x86-64-posix
echo "MinGW-w64 installed successfully"
x86_64-w64-mingw32-g++ --version

View File

@@ -20,7 +20,7 @@ pip_install \
pip_install coloredlogs packaging
pip_install onnxruntime==1.23.0
pip_install onnxscript==0.5.3
pip_install onnxscript==0.5.4
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/

View File

@@ -103,6 +103,11 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
ARG INSTALL_MINGW
COPY ./common/install_mingw.sh install_mingw.sh
RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi
RUN rm install_mingw.sh
ARG TRITON
ARG TRITON_CPU

View File

@@ -485,6 +485,22 @@ test_inductor_aoti() {
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
}
test_inductor_aoti_cross_compile_for_windows() {
TEST_REPORTS_DIR=$(pwd)/test/test-reports
mkdir -p "$TEST_REPORTS_DIR"
# Set WINDOWS_CUDA_HOME environment variable
WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted"
export WINDOWS_CUDA_HOME
echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"
echo "Contents:"
ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true
python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
}
test_inductor_cpp_wrapper_shard() {
if [[ -z "$NUM_TEST_SHARDS" ]]; then
echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
@@ -900,7 +916,7 @@ test_inductor_set_cpu_affinity(){
export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
if [[ "${TEST_CONFIG}" != *aarch64* ]]; then
if [[ "$(uname -m)" != "aarch64" ]]; then
# Use Intel OpenMP for x86
IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so"
export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD"
@@ -914,7 +930,7 @@ test_inductor_set_cpu_affinity(){
cores=$((cpus / thread_per_core))
# Set number of cores to 16 on aarch64 for performance runs
if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then
cores=16
fi
export OMP_NUM_THREADS=$cores
@@ -1667,7 +1683,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then
python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0
fi
python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py
elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then
elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then
test_linux_aarch64
elif [[ "${TEST_CONFIG}" == *backward* ]]; then
test_forward_backward_compatibility
@@ -1718,6 +1734,8 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
test_inductor_micro_benchmark
elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
test_inductor_aoti_cross_compile_for_windows
elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
install_torchvision
id=$((SHARD_NUMBER-1))

View File

@@ -3,6 +3,7 @@ ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel

View File

@@ -224,6 +224,46 @@ jobs:
continue-on-error: true
uses: ./.github/actions/download-td-artifacts
- name: Download Windows torch wheel for cross-compilation
if: matrix.win_torch_wheel_artifact != ''
uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
with:
name: ${{ matrix.win_torch_wheel_artifact }}
path: win-torch-wheel
- name: Extract Windows wheel and setup CUDA libraries
if: matrix.win_torch_wheel_artifact != ''
shell: bash
run: |
set -x
# Find the wheel file
WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
if [ -z "$WHEEL_FILE" ]; then
echo "Error: No wheel file found in win-torch-wheel directory"
exit 1
fi
echo "Found wheel file: $WHEEL_FILE"
# Unzip the wheel file
unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
echo "Extracted wheel contents"
# Setup CUDA libraries (cuda.lib and cudart.lib) directory
mkdir -p win-torch-wheel-extracted/lib/x64
if [ -f "win-torch-wheel/cuda.lib" ]; then
mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/
echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/"
fi
if [ -f "win-torch-wheel/cudart.lib" ]; then
mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/
echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/"
fi
# Verify CUDA libraries are present
echo "CUDA libraries:"
ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found"
- name: Parse ref
id: parse-ref
run: .github/scripts/parse_ref.py

View File

@@ -168,6 +168,31 @@ jobs:
run: |
.ci/pytorch/win-build.sh
# Collect Windows torch libs and CUDA libs for cross-compilation
- name: Collect Windows CUDA libs for cross-compilation
if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu'
shell: bash
run: |
set -ex
# Create directory structure if does not exist
mkdir -p /c/${{ github.run_id }}/build-results
# Copy CUDA libs
CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/
fi
if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/
fi
# List collected files
echo "Collected CUDA libs:"
ls -lah /c/${{ github.run_id }}/build-results/*.lib
# Upload to github so that people can click and download artifacts
- name: Upload artifacts to s3
if: steps.build.outcome != 'skipped'

.github/workflows/b200-distributed.yml (new file, 62 lines)
View File

@@ -0,0 +1,62 @@
name: CI for distributed tests on B200
on:
pull_request:
paths:
- .github/workflows/b200-distributed.yml
workflow_dispatch:
push:
tags:
- ciflow/b200-distributed/*
schedule:
- cron: 46 8 * * * # about 1:46am PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
with:
timeout-minutes: 1200
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
secrets: inherit

View File

@@ -88,27 +88,27 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
]}
secrets: inherit

View File

@@ -45,12 +45,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
]}
secrets: inherit

View File

@@ -200,6 +200,23 @@ jobs:
cuda-arch-list: '8.0'
secrets: inherit
# Test cross-compiled models with Windows libs extracted from wheel
cross-compile-linux-test:
name: cross-compile-linux-test
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-build
- get-label-type
- win-vs2022-cuda12_8-py3-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
test-matrix: |
{ include: [
{ config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
]}
secrets: inherit
verify-cachebench-cpu-build:
name: verify-cachebench-cpu-build
uses: ./.github/workflows/_linux-build.yml

View File

@@ -2,7 +2,6 @@
#include <mutex>
#include <ATen/CachedTensorUtils.h>
#include <c10/core/GradMode.h>
#include <c10/util/flat_hash_map.h>
namespace at::autocast {
@@ -37,29 +36,10 @@ namespace {
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;
// We maintain separate caches for gradient-enabled and gradient-disabled modes.
// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
// are not incorrectly reused in gradient-enabled contexts.
// This fixes issue #158232 while maintaining optimal performance for both modes.
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
return cached_casts_grad_enabled;
ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
return cached_casts;
}
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
return cached_casts_grad_disabled;
}
// Helper function to get the appropriate cache based on current gradient mode.
// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
// preventing incorrect cache hits when gradient mode changes.
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
return at::GradMode::is_enabled() ?
get_cached_casts_grad_enabled() :
get_cached_casts_grad_disabled();
}
std::mutex cached_casts_mutex;
@@ -106,9 +86,7 @@ thread_local bool cache_enabled = true;
void clear_cache() {
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
// Clear both caches to ensure consistent behavior regardless of current gradient mode
get_cached_casts_grad_enabled().clear();
get_cached_casts_grad_disabled().clear();
get_cached_casts().clear();
}
int increment_nesting() {
@@ -143,11 +121,6 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
// Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
// See cached_casts declaration above for detailed strategy.
//
// We maintain separate caches for gradient-enabled and gradient-disabled modes
// (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
// with torch.autocast(), while maintaining optimal performance for both training and inference.
// This fixes issue #158232 without any performance regression.
bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
arg.scalar_type() == at::kFloat && arg.requires_grad() &&
arg.is_leaf() && !arg.is_view() && cache_enabled &&

View File

@@ -229,10 +229,10 @@ private:
}
static const uint32_t kPhilox10A = 0x9E3779B9;
static const uint32_t kPhilox10B = 0xBB67AE85;
static const uint32_t kPhiloxSA = 0xD2511F53;
static const uint32_t kPhiloxSB = 0xCD9E8D57;
static constexpr uint32_t kPhilox10A = 0x9E3779B9;
static constexpr uint32_t kPhilox10B = 0xBB67AE85;
static constexpr uint32_t kPhiloxSA = 0xD2511F53;
static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
};
typedef philox_engine Philox4_32;

View File

@@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#endif
#include <ATen/cpu/vec/vec128/vec128_convert.h>

View File

@@ -0,0 +1,794 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#define VEC_INT_NEON_TEMPLATE(vl, bit) \
template <> \
struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {}; \
\
template <> \
class Vectorized<int##bit##_t> { \
using neon_type = int##bit##x##vl##_t; \
\
private: \
neon_type values; \
\
public: \
using value_type = int##bit##_t; \
using size_type = int; \
static constexpr size_type size() { \
return vl; \
} \
Vectorized() { \
values = vdupq_n_s##bit(0); \
} \
Vectorized(neon_type v) : values(v) {} \
Vectorized(int##bit##_t val); \
template < \
typename... Args, \
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
Vectorized(Args... vals) { \
__at_align__ int##bit##_t buffer[size()] = {vals...}; \
values = vld1q_s##bit(buffer); \
} \
operator neon_type() const { \
return values; \
} \
static Vectorized<int##bit##_t> loadu( \
const void* ptr, \
int64_t count = size()); \
void store(void* ptr, int64_t count = size()) const; \
template <int64_t mask> \
static Vectorized<int##bit##_t> blend( \
const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b); \
static Vectorized<int##bit##_t> blendv( \
const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b, \
const Vectorized<int##bit##_t>& mask_) { \
return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \
} \
template <typename step_t> \
static Vectorized<int##bit##_t> arange( \
value_type base = 0, \
step_t step = static_cast<step_t>(1)); \
static Vectorized<int##bit##_t> set( \
const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b, \
int64_t count = size()); \
const int##bit##_t& operator[](int idx) const = delete; \
int##bit##_t& operator[](int idx) = delete; \
Vectorized<int##bit##_t> abs() const { \
return vabsq_s##bit(values); \
} \
Vectorized<int##bit##_t> real() const { \
return values; \
} \
Vectorized<int##bit##_t> imag() const { \
return vdupq_n_s##bit(0); \
} \
Vectorized<int##bit##_t> conj() const { \
return values; \
} \
Vectorized<int##bit##_t> neg() const { \
return vnegq_s##bit(values); \
} \
int##bit##_t reduce_add() const { \
return vaddvq_s##bit(values); \
} \
int##bit##_t reduce_max() const; \
Vectorized<int##bit##_t> operator==( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> operator!=( \
const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> operator<( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> operator<=( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> operator>( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> operator>=( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \
}; \
template <> \
Vectorized<int##bit##_t> inline operator+( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return vaddq_s##bit(a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator-( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return vsubq_s##bit(a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator&( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return vandq_s##bit(a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator|( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return vorrq_s##bit(a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator^( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return veorq_s##bit(a, b); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq( \
const Vectorized<int##bit##_t>& other) const { \
return (*this == other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne( \
const Vectorized<int##bit##_t>& other) const { \
return (*this != other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt( \
const Vectorized<int##bit##_t>& other) const { \
return (*this > other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge( \
const Vectorized<int##bit##_t>& other) const { \
return (*this >= other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt( \
const Vectorized<int##bit##_t>& other) const { \
return (*this < other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le( \
const Vectorized<int##bit##_t>& other) const { \
return (*this <= other) & Vectorized<int##bit##_t>(1); \
}
VEC_INT_NEON_TEMPLATE(2, 64)
VEC_INT_NEON_TEMPLATE(4, 32)
VEC_INT_NEON_TEMPLATE(8, 16)
VEC_INT_NEON_TEMPLATE(16, 8)
inline int32_t Vectorized<int32_t>::reduce_max() const {
return vmaxvq_s32(values);
}
inline int16_t Vectorized<int16_t>::reduce_max() const {
return vmaxvq_s16(values);
}
inline int8_t Vectorized<int8_t>::reduce_max() const {
return vmaxvq_s8(values);
}
template <>
Vectorized<int32_t> inline operator*(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
return vmulq_s32(a, b);
}
template <>
Vectorized<int16_t> inline operator*(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
return vmulq_s16(a, b);
}
template <>
Vectorized<int8_t> inline operator*(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
return vmulq_s8(a, b);
}
template <>
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
int64x2_t val = a;
return ~val;
}
template <>
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
return vmvnq_s32(a);
}
template <>
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
return vmvnq_s16(a);
}
template <>
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
return vmvnq_s8(a);
}
inline Vectorized<int64_t> Vectorized<int64_t>::operator!=(
const Vectorized<int64_t>& other) const {
return ~(*this == other);
}
inline Vectorized<int32_t> Vectorized<int32_t>::operator!=(
const Vectorized<int32_t>& other) const {
return ~(*this == other);
}
inline Vectorized<int16_t> Vectorized<int16_t>::operator!=(
const Vectorized<int16_t>& other) const {
return ~(*this == other);
}
inline Vectorized<int8_t> Vectorized<int8_t>::operator!=(
const Vectorized<int8_t>& other) const {
return ~(*this == other);
}
template <>
Vectorized<int32_t> inline minimum(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
return vminq_s32(a, b);
}
template <>
Vectorized<int16_t> inline minimum(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
return vminq_s16(a, b);
}
template <>
Vectorized<int8_t> inline minimum(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
return vminq_s8(a, b);
}
template <>
Vectorized<int32_t> inline maximum(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
return vmaxq_s32(a, b);
}
template <>
Vectorized<int16_t> inline maximum(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
return vmaxq_s16(a, b);
}
template <>
Vectorized<int8_t> inline maximum(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
return vmaxq_s8(a, b);
}
template <int64_t mask>
Vectorized<int64_t> Vectorized<int64_t>::blend(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint64x2_t maskArray = {
(mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0,
(mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s64(maskArray, b.values, a.values);
}
template <int64_t mask>
Vectorized<int32_t> Vectorized<int32_t>::blend(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint32x4_t maskArray = {
(mask & 1LL) ? 0xFFFFFFFF : 0,
(mask & 2LL) ? 0xFFFFFFFF : 0,
(mask & 4LL) ? 0xFFFFFFFF : 0,
(mask & 8LL) ? 0xFFFFFFFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s32(maskArray, b.values, a.values);
}
template <int64_t mask>
Vectorized<int16_t> Vectorized<int16_t>::blend(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint16x8_t maskArray = {
(mask & 1LL) ? 0xFFFF : 0,
(mask & 2LL) ? 0xFFFF : 0,
(mask & 4LL) ? 0xFFFF : 0,
(mask & 8LL) ? 0xFFFF : 0,
(mask & 16LL) ? 0xFFFF : 0,
(mask & 32LL) ? 0xFFFF : 0,
(mask & 64LL) ? 0xFFFF : 0,
(mask & 128LL) ? 0xFFFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s16(maskArray, b.values, a.values);
}
template <int64_t mask>
Vectorized<int8_t> Vectorized<int8_t>::blend(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
(mask & 1LL) ? 0xFF : 0,
(mask & 2LL) ? 0xFF : 0,
(mask & 4LL) ? 0xFF : 0,
(mask & 8LL) ? 0xFF : 0,
(mask & 16LL) ? 0xFF : 0,
(mask & 32LL) ? 0xFF : 0,
(mask & 64LL) ? 0xFF : 0,
(mask & 128LL) ? 0xFF : 0,
(mask & 256LL) ? 0xFF : 0,
(mask & 512LL) ? 0xFF : 0,
(mask & 1024LL) ? 0xFF : 0,
(mask & 2048LL) ? 0xFF : 0,
(mask & 4096LL) ? 0xFF : 0,
(mask & 8192LL) ? 0xFF : 0,
(mask & 16384LL) ? 0xFF : 0,
(mask & 32768LL) ? 0xFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s8(maskArray, b.values, a.values);
}
#define VEC_INT_NEON_OPS(vl, bit) \
inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) { \
values = vdupq_n_s##bit(val); \
} \
inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu( \
const void* ptr, int64_t count) { \
if (count == size()) { \
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr)); \
} else { \
__at_align__ int##bit##_t tmp_values[size()]; \
for (const auto i : c10::irange(size())) { \
tmp_values[i] = 0; \
} \
std::memcpy( \
tmp_values, \
reinterpret_cast<const int##bit##_t*>(ptr), \
count * sizeof(int##bit##_t)); \
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \
} \
} \
inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count) \
const { \
if (count == size()) { \
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values); \
} else { \
int##bit##_t tmp_values[size()]; \
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values); \
std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t)); \
} \
}
VEC_INT_NEON_OPS(2, 64)
VEC_INT_NEON_OPS(4, 32)
VEC_INT_NEON_OPS(8, 16)
VEC_INT_NEON_OPS(16, 8)
template <>
Vectorized<int64_t> inline operator*(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
return x * y;
}
template <>
Vectorized<int64_t> inline operator/(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
return x / y;
}
template <>
Vectorized<int32_t> inline operator/(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
int32x4_t x = a;
int32x4_t y = b;
return x / y;
}
inline int64_t Vectorized<int64_t>::reduce_max() const {
return std::max(values[0], values[1]);
}
template <>
Vectorized<int64_t> inline minimum(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
return {std::min(x[0], y[0]), std::min(x[1], y[1])};
}
template <>
Vectorized<int64_t> inline maximum(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
return {std::max(x[0], y[0]), std::max(x[1], y[1])};
}
template <typename step_t>
inline Vectorized<int64_t> Vectorized<int64_t>::arange(
int64_t base,
step_t step) {
const Vectorized<int64_t> base_vec(base);
const Vectorized<int64_t> step_vec(step);
const int64x2_t step_sizes = {0, 1};
return base_vec.values + step_sizes * step_vec.values;
}
template <typename step_t>
inline Vectorized<int32_t> Vectorized<int32_t>::arange(
int32_t base,
step_t step) {
const Vectorized<int32_t> base_vec(base);
const Vectorized<int32_t> step_vec(step);
const int32x4_t step_sizes = {0, 1, 2, 3};
return vmlaq_s32(base_vec, step_sizes, step_vec);
}
template <typename step_t>
inline Vectorized<int16_t> Vectorized<int16_t>::arange(
int16_t base,
step_t step) {
const Vectorized<int16_t> base_vec(base);
const Vectorized<int16_t> step_vec(step);
const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7};
return vmlaq_s16(base_vec, step_sizes, step_vec);
}
template <typename step_t>
inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) {
const Vectorized<int8_t> base_vec(base);
const Vectorized<int8_t> step_vec(step);
const int8x16_t step_sizes = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
return vmlaq_s8(base_vec, step_sizes, step_vec);
}
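The arange specializations above compute base + {0, 1, ...} * step with a fused multiply-accumulate (the int64 variant falls back to plain vector arithmetic, since NEON has no 64-bit vmlaq). A standalone sketch of the int32 case; arange_s32 and the values are hypothetical, not ATen code.
#include <arm_neon.h>
#include <cstdio>
// arange(base, step) for 4 x int32: base + {0, 1, 2, 3} * step, fused with vmlaq_s32.
int32x4_t arange_s32(int32_t base, int32_t step) {
  const int32x4_t base_vec = vdupq_n_s32(base);
  const int32x4_t step_vec = vdupq_n_s32(step);
  const int32x4_t lane_idx = {0, 1, 2, 3};
  return vmlaq_s32(base_vec, lane_idx, step_vec); // base + lane_idx * step
}
int main() {
  int32_t out[4];
  vst1q_s32(out, arange_s32(7, 3));
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); // 7 10 13 16
  return 0;
}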
template <>
Vectorized<int64_t> inline operator>>(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
uint64x2_t u = vreinterpretq_u64_s64(y);
uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)};
return x >> vreinterpretq_s64_u64(z);
}
template <>
Vectorized<int32_t> inline operator>>(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
int32x4_t x = a;
int32x4_t y = b;
uint32x4_t bound = vdupq_n_u32(31);
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
return x >> vreinterpretq_s32_u32(z);
}
template <>
Vectorized<int16_t> inline operator>>(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
int16x8_t x = a;
int16x8_t y = b;
uint16x8_t bound = vdupq_n_u16(15);
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
return x >> vreinterpretq_s16_u16(z);
}
template <>
Vectorized<int8_t> inline operator>>(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
int8x16_t x = a;
int8x16_t y = b;
uint8x16_t bound = vdupq_n_u8(7);
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
return x >> z;
}
template <>
Vectorized<int64_t> inline operator<<(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t y = b;
uint64x2_t u = vreinterpretq_u64_s64(y);
uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)};
return vshlq_s64(a, vreinterpretq_s64_u64(z));
}
template <>
Vectorized<int32_t> inline operator<<(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
int32x4_t y = b;
uint32x4_t bound = vdupq_n_u32(32);
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
return vshlq_s32(a, vreinterpretq_s32_u32(z));
}
template <>
Vectorized<int16_t> inline operator<<(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
int16x8_t y = b;
uint16x8_t bound = vdupq_n_u16(16);
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
return vshlq_s16(a, vreinterpretq_s16_u16(z));
}
template <>
Vectorized<int8_t> inline operator<<(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
int8x16_t y = b;
uint8x16_t bound = vdupq_n_u8(8);
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
return vshlq_s8(a, z);
}
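The shift operators above clamp the per-lane shift count before shifting: operator>> clamps to bit-width minus one so an oversized count still sign-fills, and operator<< clamps to the bit width, which makes any oversized count produce zero through vshlq instead of wrapping its byte-wise shift amount. A standalone sketch of the right-shift case; the values are hypothetical.
#include <arm_neon.h>
#include <cstdio>
int main() {
  // Clamp per-lane shift counts before shifting, as the operator>> specializations above do:
  // an out-of-range count saturates to 31, so the lane still gets a well-defined
  // arithmetic result (sign fill for negatives, zero for non-negatives).
  const int32x4_t x = {-8, -8, 256, 256};
  const int32x4_t counts = {1, 100, 1, 100}; // 100 is out of range for int32
  const uint32x4_t bound = vdupq_n_u32(31);
  const uint32x4_t clamped = vminq_u32(vreinterpretq_u32_s32(counts), bound);
  const int32x4_t r = x >> vreinterpretq_s32_u32(clamped);
  int32_t out[4];
  vst1q_s32(out, r);
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); // -4 -1 128 0
  return 0;
}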
inline Vectorized<int64_t> Vectorized<int64_t>::set(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b,
int64_t count) {
if (count == 0) {
return a;
} else if (count >= 2) {
return b;
} else {
int64x2_t c = {b.values[0], a.values[1]};
return c;
}
}
inline Vectorized<int32_t> Vectorized<int32_t>::set(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b,
int64_t count) {
if (count == 0) {
return a;
} else if (count >= 4) {
return b;
} else {
// Build a lane mask: element i is all ones if i < count, all zeros
// otherwise.
uint32x4_t maskArray = {
(count >= 1LL) ? 0xFFFFFFFF : 0,
(count >= 2LL) ? 0xFFFFFFFF : 0,
(count >= 3LL) ? 0xFFFFFFFF : 0,
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s32(maskArray, b.values, a.values);
}
}
inline Vectorized<int16_t> Vectorized<int16_t>::set(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b,
int64_t count) {
if (count == 0) {
return a;
} else if (count >= 8) {
return b;
} else {
// Build a lane mask: element i is all ones if i < count, all zeros
// otherwise.
uint16x8_t maskArray = {
static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s16(maskArray, b.values, a.values);
}
}
inline Vectorized<int8_t> Vectorized<int8_t>::set(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b,
int64_t count) {
if (count == 0) {
return a;
} else if (count >= 16) {
return b;
} else {
// Build a lane mask: element i is all ones if i < count, all zeros
// otherwise.
uint8x16_t maskArray = {
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s8(maskArray, b.values, a.values);
}
}
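The set() specializations above return the first 'count' lanes from b and the remaining lanes from a, building the partial-lane mask from the count and reusing BSL. A scalar reference sketch of that contract; set_ref is a hypothetical helper, not ATen code.
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
// Scalar reference for set(a, b, count): the first 'count' lanes come from b,
// the rest from a.
template <typename T, std::size_t N>
std::array<T, N> set_ref(const std::array<T, N>& a, const std::array<T, N>& b, int64_t count) {
  std::array<T, N> out = a;
  const int64_t n = std::clamp<int64_t>(count, 0, static_cast<int64_t>(N));
  std::copy_n(b.begin(), n, out.begin());
  return out;
}
int main() {
  const std::array<int32_t, 4> a{0, 1, 2, 3};
  const std::array<int32_t, 4> b{10, 11, 12, 13};
  const auto r = set_ref(a, b, 2);
  std::printf("%d %d %d %d\n", r[0], r[1], r[2], r[3]); // 10 11 2 3
  return 0;
}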
template <>
Vectorized<int16_t> inline operator/(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
Vectorized<int32_t> highBitsA = vmovl_high_s16(a);
Vectorized<int32_t> highBitsB = vmovl_high_s16(b);
Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a));
Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b));
int32x4_t highBitsResult = highBitsA / highBitsB;
int32x4_t lowBitsResult = lowBitsA / lowBitsB;
return vuzp1q_s16(
vreinterpretq_s16_s32(lowBitsResult),
vreinterpretq_s16_s32(highBitsResult));
}
template <>
Vectorized<int8_t> inline operator/(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
Vectorized<int16_t> highBitsA = vmovl_high_s8(a);
Vectorized<int16_t> highBitsB = vmovl_high_s8(b);
Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a));
Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b));
int16x8_t highBitsResult = highBitsA / highBitsB;
int16x8_t lowBitsResult = lowBitsA / lowBitsB;
return vuzp1q_s8(
vreinterpretq_s8_s16(lowBitsResult),
vreinterpretq_s8_s16(highBitsResult));
}
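int16 and int8 have no NEON divide, so the operator/ specializations above widen each half to the next wider type, divide there, and repack the quotients with vuzp1q. A standalone sketch of the int16 path; div_s16 and the values are hypothetical.
#include <arm_neon.h>
#include <cstdio>
// Divide 8 x int16 by widening each half to int32, dividing there, and packing
// the low 16 bits of every quotient back with vuzp1q_s16.
int16x8_t div_s16(int16x8_t a, int16x8_t b) {
  const int32x4_t hi = vmovl_high_s16(a) / vmovl_high_s16(b);
  const int32x4_t lo = vmovl_s16(vget_low_s16(a)) / vmovl_s16(vget_low_s16(b));
  return vuzp1q_s16(vreinterpretq_s16_s32(lo), vreinterpretq_s16_s32(hi));
}
int main() {
  const int16x8_t a = {100, -100, 7, 8, 9, 10, 11, 12};
  const int16x8_t b = {3, 3, 2, 2, 2, 2, 2, 2};
  int16_t out[8];
  vst1q_s16(out, div_s16(a, b));
  for (int i = 0; i < 8; i++) {
    std::printf("%d ", out[i]); // prints: 33 -33 3 4 4 5 5 6
  }
  std::printf("\n");
  return 0;
}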
template <>
Vectorized<int64_t> inline clamp(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& min,
const Vectorized<int64_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<int32_t> inline clamp(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& min,
const Vectorized<int32_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<int16_t> inline clamp(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& min,
const Vectorized<int16_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<int8_t> inline clamp(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& min,
const Vectorized<int8_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<int64_t> inline clamp_max(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<int32_t> inline clamp_max(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<int16_t> inline clamp_max(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<int8_t> inline clamp_max(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<int64_t> inline clamp_min(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& min) {
return maximum(min, a);
}
template <>
Vectorized<int32_t> inline clamp_min(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& min) {
return maximum(min, a);
}
template <>
Vectorized<int16_t> inline clamp_min(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& min) {
return maximum(min, a);
}
template <>
Vectorized<int8_t> inline clamp_min(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& min) {
return maximum(min, a);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum(
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<int8_t> src) {
auto s8x8 = vld1_s8(src.operator const int8_t*());
auto s8x8 = vget_low_s8(src);
auto s16x8 = vmovl_s8(s8x8);
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
@ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
Vectorized<float> inline convert_int8_half_register_to_float(
at::vec::Vectorized<int8_t> src) {
auto s8x8 = vld1_s8(src.operator const int8_t*());
auto s8x8 = vget_low_s8(src);
auto s16x8 = vmovl_s8(s8x8);
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
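The hunk above swaps a reload through a pointer (vld1_s8) for a direct extraction of the register's low half (vget_low_s8) ahead of the usual widening chain. A standalone sketch of that chain; int8_low_to_float and the values are hypothetical, not the ATen function.
#include <arm_neon.h>
#include <cstdio>
// Widen the low 8 int8 lanes of a q-register to two float32x4_t values:
// vget_low_s8 -> vmovl_s8 -> vmovl_s16 -> vcvtq_f32_s32.
void int8_low_to_float(int8x16_t src, float32x4_t* lo, float32x4_t* hi) {
  const int8x8_t s8x8 = vget_low_s8(src);
  const int16x8_t s16x8 = vmovl_s8(s8x8);
  *lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16x8)));
  *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16x8)));
}
int main() {
  const int8x16_t src = {-3, -2, -1, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0};
  float32x4_t lo, hi;
  int8_low_to_float(src, &lo, &hi);
  float out[8];
  vst1q_f32(out, lo);
  vst1q_f32(out + 4, hi);
  for (int i = 0; i < 8; i++) {
    std::printf("%.0f ", out[i]); // prints: -3 -2 -1 0 1 2 3 4
  }
  std::printf("\n");
  return 0;
}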

View File

@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
*/
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
* and size of the internal state.
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;
detail::check_rng_state(new_state);
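The two hunks above only make the size constants constexpr; the layout they describe is an 8-byte seed followed by an 8-byte Philox offset in a 16-byte CPU byte tensor. A plain-C++ sketch of packing and unpacking that layout; the raw buffer and the seed-first ordering are assumptions drawn from the surrounding generator code, which is not shown in full here.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
int main() {
  constexpr std::size_t seed_size = sizeof(uint64_t);
  constexpr std::size_t offset_size = sizeof(int64_t);
  constexpr std::size_t total_size = seed_size + offset_size;
  uint8_t state[total_size]; // stands in for the CPU byte tensor
  const uint64_t seed = 0x1234;
  const int64_t philox_offset = 40;
  // Pack: seed first, then the Philox offset.
  std::memcpy(state, &seed, seed_size);
  std::memcpy(state + seed_size, &philox_offset, offset_size);
  // Unpack, the way set_state would read it back.
  uint64_t seed_out = 0;
  int64_t offset_out = 0;
  std::memcpy(&seed_out, state, seed_size);
  std::memcpy(&offset_out, state + seed_size, offset_size);
  std::printf("%llu %lld\n", static_cast<unsigned long long>(seed_out),
              static_cast<long long>(offset_out)); // 4660 40
  return 0;
}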

View File

@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (
namespace at::native {
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;
static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);

View File

@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return n <= intmax && incx <= intmax;
}
@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
int64_t incx,
[[maybe_unused]] float beta,
int64_t incy) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
(incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

View File

@ -127,7 +127,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
const static scalar_t kTailValues[] = {
constexpr static scalar_t kTailValues[] = {
0.0810614667953272,
0.0413406959554092,
0.0276779256849983,
@ -139,7 +139,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
0.00925546218271273,
0.00833056343336287
};
if (k <= 9) {
if (k < sizeof(kTailValues)/sizeof(scalar_t)) {
return kTailValues[static_cast<size_t>(k)];
}
scalar_t kp1sq = (k + 1) * (k + 1);
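The hunk above derives the table bound from the table itself rather than a hard-coded literal; C++17's std::size expresses the same idea, with the guard kept strictly below the element count so the index stays in range. A tiny hypothetical sketch, not related to the CUDA kernel above.
#include <cstddef>
#include <cstdio>
#include <iterator>
int main() {
  constexpr double kTail[] = {0.08, 0.04, 0.03}; // hypothetical 3-entry table
  const double k = 2.0;
  // Index only while k is strictly below the element count.
  if (k >= 0 && k < static_cast<double>(std::size(kTail))) {
    std::printf("%f\n", kTail[static_cast<std::size_t>(k)]);
  } else {
    std::printf("outside table\n");
  }
  return 0;
}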

View File

@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
template <typename scalar_t>
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
static const scalar_t lanczos_sum_expg_scaled_num[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
template <typename scalar_t>
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
static const scalar_t d[25][25] =
static constexpr scalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,

View File

@ -62,7 +62,7 @@
#include <utility>
#include <vector>
static const int MIOPEN_DIM_MAX = 5;
static constexpr int MIOPEN_DIM_MAX = 5;
namespace at::meta {

View File

@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
// We keep this structure for BC and consider it deprecated.
// See HelperInterpNearestExact as a replacement.
static const int interp_size = 1;
static constexpr int interp_size = 1;
static inline void init_indices_weights(
at::ScalarType output_type,
@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {
struct HelperInterpLinear : public HelperInterpBase {
static const int interp_size = 2;
static constexpr int interp_size = 2;
// Compute indices and weights for each interpolated dimension
// indices_weights = {
@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {
struct HelperInterpCubic : public HelperInterpBase {
static const int interp_size = 4;
static constexpr int interp_size = 4;
// Compute indices and weights for each interpolated dimension
// indices_weights = {

View File

@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
}
static const int BLOCK_THREADS = 256;
static constexpr int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)

View File

@ -36,9 +36,9 @@ namespace at::native {
namespace {
#if defined(USE_ROCM)
static const int BLOCKDIMY = 16;
static constexpr int BLOCKDIMY = 16;
#else
static const int BLOCKDIMY = 32;
static constexpr int BLOCKDIMY = 32;
#endif
template

View File

@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t ax, fac, res, num, numfac;
static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
7.09782712893383996843E2 : 88.72283905206835;
static const accscalar_t EXP1 = 2.718281828459045;
static const accscalar_t lanczos_g = 6.024680040776729583740234375;
constexpr accscalar_t EXP1 = 2.718281828459045;
constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
if (::fabs(a - x) > 0.4 * ::fabs(a)) {
ax = a * ::log(x) - x - ::lgamma(a);
@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const int MAXITER = 2000;
constexpr int MAXITER = 2000;
int i;
accscalar_t ans, ax, c, r;
@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
accscalar_t fac = 1;
accscalar_t sum = 0;
accscalar_t term, logx;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) {
@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t d[25][25] =
constexpr accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
int k, n, sgn;
int maxpow = 0;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
accscalar_t lambda = x / a;
accscalar_t sigma = (x - a) / a;
@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
int i;
accscalar_t ans, ax, c, yc, r, t, y, z;
accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
4.503599627370496e15 : 16777216.;
static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x);
@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
// boundary values following SciPy
if ((x < 0) || (a < 0)) {

View File

@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
const auto digamma_string = jiterator_stringify(
template <typename T>
T digamma(T x) {
static const double PI_f64 = 3.14159265358979323846;
static constexpr double PI_f64 = 3.14159265358979323846;
// Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
if (x == 0) {
@ -3072,9 +3072,9 @@ template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const double PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
static constexpr double PI_f64 = 3.14159265358979323846;
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
constexpr accscalar_t A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,

View File

@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
return 0;
}
static const int size = 2;
static constexpr int size = 2;
};
// taken from
@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
return 0;
}
static const int size = 4;
static constexpr int size = 4;
};
template <typename accscalar_t>

View File

@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
// else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
// only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
if (mat1.dim() == 1 && mat2.dim() == 1) {
// aten::dot
return mat1.size(0) > mkldnn_gemm_min_size;

View File

@ -1,16 +1,16 @@
#pragma once
#include <c10/metal/common.h>
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeSharedParams {
template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
struct CatSharedParams {
int32_t ndim;
int32_t cat_dim;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
};
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeInputParams {
template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
struct CatInputParams {
idx_type_t cat_dim_offset;
idx_type_t input_element_offset;
::c10::metal::array<idx_type_t, N> input_strides;

View File

@ -6,12 +6,12 @@
using namespace metal;
using namespace c10::metal;
template <typename T_in, typename T_out>
kernel void cat_large(
template <typename I, typename T_in, typename T_out>
kernel void cat(
constant T_in* input [[buffer(0)]],
device T_out* output [[buffer(1)]],
constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
constant CatLargeInputParams<>& input_params [[buffer(3)]],
constant CatSharedParams<I>& shared_params [[buffer(2)]],
constant CatInputParams<I>& input_params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto ndim = shared_params.ndim;
auto cat_dim = shared_params.cat_dim;
@ -23,9 +23,9 @@ kernel void cat_large(
constant auto& input_strides = input_params.input_strides;
constant auto& input_sizes = input_params.input_sizes;
auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
int64_t input_offset = 0;
int64_t output_offset = 0;
auto input_element_idx = static_cast<I>(tid) + input_element_offset;
I input_offset = 0;
I output_offset = 0;
for (auto dim = ndim - 1; dim >= 0; dim--) {
auto dim_size = input_sizes[dim];
@ -42,41 +42,45 @@ kernel void cat_large(
output[output_offset] = static_cast<T_out>(input[input_offset]);
}
#define REGISTER_CAT_LARGE_OP(T_in, T_out) \
template [[host_name("cat_large_" #T_in "_" #T_out)]] \
kernel void cat_large<T_in, T_out>( \
constant T_in * input [[buffer(0)]], \
device T_out * output [[buffer(1)]], \
constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
constant CatLargeInputParams<> & input_params [[buffer(3)]], \
#define REGISTER_CAT_OP(I, T_in, T_out) \
template [[host_name("cat_" #I "_" #T_in "_" #T_out)]] \
kernel void cat<I, T_in, T_out>( \
constant T_in * input [[buffer(0)]], \
device T_out * output [[buffer(1)]], \
constant CatSharedParams<I> & shared_params [[buffer(2)]], \
constant CatInputParams<I> & input_params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
REGISTER_CAT_LARGE_OP(float, T_out); \
REGISTER_CAT_LARGE_OP(half, T_out); \
REGISTER_CAT_LARGE_OP(bfloat, T_out); \
REGISTER_CAT_LARGE_OP(int, T_out); \
REGISTER_CAT_LARGE_OP(uint, T_out); \
REGISTER_CAT_LARGE_OP(long, T_out); \
REGISTER_CAT_LARGE_OP(ulong, T_out); \
REGISTER_CAT_LARGE_OP(short, T_out); \
REGISTER_CAT_LARGE_OP(ushort, T_out); \
REGISTER_CAT_LARGE_OP(char, T_out); \
REGISTER_CAT_LARGE_OP(uchar, T_out); \
REGISTER_CAT_LARGE_OP(bool, T_out);
#define REGISTER_CAT_OP_ALL_INPUT_TYPES(I, T_out) \
REGISTER_CAT_OP(I, float, T_out); \
REGISTER_CAT_OP(I, half, T_out); \
REGISTER_CAT_OP(I, bfloat, T_out); \
REGISTER_CAT_OP(I, int, T_out); \
REGISTER_CAT_OP(I, uint, T_out); \
REGISTER_CAT_OP(I, long, T_out); \
REGISTER_CAT_OP(I, ulong, T_out); \
REGISTER_CAT_OP(I, short, T_out); \
REGISTER_CAT_OP(I, ushort, T_out); \
REGISTER_CAT_OP(I, char, T_out); \
REGISTER_CAT_OP(I, uchar, T_out); \
REGISTER_CAT_OP(I, bool, T_out);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
#define REGISTER_CAT_FOR_INDEX_TYPE(I) \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, half); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bfloat); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, int); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uint); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, long); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ulong); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, short); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ushort); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, char); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uchar); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bool); \
\
REGISTER_CAT_OP(I, float2, float2); \
REGISTER_CAT_OP(I, half2, half2);
REGISTER_CAT_LARGE_OP(float2, float2);
REGISTER_CAT_LARGE_OP(half2, half2);
REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
REGISTER_CAT_FOR_INDEX_TYPE(int32_t);

View File

@ -3,6 +3,7 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/Pool.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/OperationUtils.h>
@ -69,29 +70,40 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
}
}
// This implementation of cat is used only if one of the inputs or the output is
// too large to use MPSGraph.
template <typename T>
std::string get_type_str();
template <>
std::string get_type_str<int64_t>() {
return "int64_t";
}
template <>
std::string get_type_str<int32_t>() {
return "int32_t";
}
// NOTE: `output` is expected to already have the correct size.
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
CatLargeSharedParams shared_params;
template <typename idx_type_t>
static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
CatSharedParams<idx_type_t> shared_params;
shared_params.ndim = output.dim();
shared_params.cat_dim = dimension;
for (const auto dim : c10::irange(output.dim())) {
shared_params.output_strides[dim] = output.stride(dim);
shared_params.output_sizes[dim] = output.size(dim);
shared_params.output_strides[dim] = safe_downcast<idx_type_t, int64_t>(output.stride(dim));
shared_params.output_sizes[dim] = safe_downcast<idx_type_t, int64_t>(output.size(dim));
}
int64_t cat_dim_offset = 0;
idx_type_t cat_dim_offset = 0;
size_t input_idx = 0;
MPSStream* stream = getCurrentMPSStream();
// Launch a separate kernel for each input. This will produce some overhead,
// but that should be relatively minimal since at least one of the inputs is
// very large. In order to launch only one kernel to process all inputs, we
// would have to copy all the input tensor data into a packed buffer, which
// would not be ideal.
// Launch a separate kernel for each input. This will produce some overhead.
// In order to launch only one kernel to process all inputs, we would have to
// copy all the input tensor data into a packed buffer, which would not be
// ideal.
for (const Tensor& input : inputs) {
if (input.numel() == 0) {
continue;
@ -104,21 +116,23 @@ static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimen
for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
auto num_threads = std::min(max_num_threads, numel_remaining);
CatLargeInputParams input_params;
CatInputParams<idx_type_t> input_params;
input_params.cat_dim_offset = cat_dim_offset;
input_params.input_element_offset = input.numel() - numel_remaining;
input_params.cat_dim_offset = safe_downcast<idx_type_t, int64_t>(cat_dim_offset);
input_params.input_element_offset = safe_downcast<idx_type_t, int64_t>(input.numel() - numel_remaining);
for (const auto dim : c10::irange(input.dim())) {
input_params.input_strides[dim] = input.stride(dim);
input_params.input_sizes[dim] = input.size(dim);
input_params.input_strides[dim] = safe_downcast<idx_type_t, int64_t>(input.stride(dim));
input_params.input_sizes[dim] = safe_downcast<idx_type_t, int64_t>(input.size(dim));
}
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(
fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("cat_{}_{}_{}",
get_type_str<idx_type_t>(),
scalarToMetalTypeString(input),
scalarToMetalTypeString(output)));
getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
[computeEncoder setComputePipelineState:pipeline_state];
mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
@ -294,13 +308,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
" and out is on ",
out.device());
// TODO: For better performance by eliminating input tensor gathering and post transpose,
// TODO: it is better to keep the out tensor's memory format.
// TODO: dimension needs to be recomputed as:
// TODO: dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim()- dim; otherwise dim = dim-1
if (needsGather(out)) {
out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
}
std::vector<int64_t> size(notSkippedTensor.sizes().vec());
// Compute size of the result in the cat dimension
@ -331,82 +338,9 @@ TORCH_IMPL_FUNC(cat_out_mps)
has_large_tensor |= isTooLargeForMPSGraph(out);
if (has_large_tensor) {
return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphTensor* outputTensor_ = nil;
};
@autoreleasepool {
std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
if (!all_same_dtype) {
key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
} else {
key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
}
for (auto idx : skipped_tensor_indices) {
key += "," + std::to_string(idx);
}
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array);
newCachedGraph->inputTensors_.reserve(len_tensor_array);
for (const auto idx : c10::irange(len_tensor_array)) {
const Tensor& tensor = input_tensors[idx];
auto scalar_type = getMPSScalarType(tensor.scalar_type());
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type);
if (tensor.scalar_type() != out_dtype) {
castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
toType:getMPSDataType(out_dtype)
name:@"castInput"];
} else {
castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
}
}
auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
dimension:dimension // Maybe convert this from int64_t -> int32
name:nil];
if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
}
newCachedGraph->outputTensor_ = outputTensor;
});
std::vector<Placeholder> inputPlaceholders;
int i = 0;
int t_idx = 0;
for (const Tensor& tensor : materialized_inputs) {
if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
auto scalar_type = getMPSScalarType(tensor.scalar_type());
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, nullptr, true, scalar_type);
t_idx++;
}
i++;
}
auto outputDataType = getMPSScalarType(out.scalar_type());
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
for (auto& inputPlaceholder : inputPlaceholders) {
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
}
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out);
} else {
return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out);
}
}

View File

@ -6531,6 +6531,7 @@
dispatch:
CPU, CUDA: var
MPS: var_mps
MTIA: var_mtia
tags: core
- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)

View File

@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
#if defined(__ARM_NEON__) || defined(__aarch64__)
const static int PARALLEL_THRESHOLD = 1 << 20;
constexpr static int PARALLEL_THRESHOLD = 1 << 20;
// Generic template defaults to naive quantize implementation
template <typename T>

View File

@ -1388,7 +1388,7 @@ namespace at::native {
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
static std::optional<at::Tensor> other = std::nullopt;
static const std::string_view binary_post_op = "none";
constexpr std::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp,

View File

@ -16,8 +16,8 @@ namespace {
#ifdef USE_PYTORCH_QNNPACK
const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
const static int qnnpack_softmax_output_zero_point = 0;
constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
constexpr static int qnnpack_softmax_output_zero_point = 0;
bool is_qnnpack_compatible(
const Tensor& qx,

View File

@ -110,9 +110,9 @@ class ApplyLogSumExp {
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale =
static int constexpr kElementsPerAccess = ElementsPerAccess;
static int constexpr kCount = kElementsPerAccess;
static constexpr ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;

View File

@ -14,16 +14,16 @@ using namespace at;
namespace {
const auto int_min = std::numeric_limits<int>::min();
const auto int_max = std::numeric_limits<int>::max();
const auto long_min = std::numeric_limits<int64_t>::min();
const auto long_max = std::numeric_limits<int64_t>::max();
const auto float_lowest = std::numeric_limits<float>::lowest();
const auto float_min = std::numeric_limits<float>::min();
const auto float_max = std::numeric_limits<float>::max();
const auto double_lowest = std::numeric_limits<double>::lowest();
const auto double_min = std::numeric_limits<double>::min();
const auto double_max = std::numeric_limits<double>::max();
constexpr auto int_min = std::numeric_limits<int>::min();
constexpr auto int_max = std::numeric_limits<int>::max();
constexpr auto long_min = std::numeric_limits<int64_t>::min();
constexpr auto long_max = std::numeric_limits<int64_t>::max();
constexpr auto float_lowest = std::numeric_limits<float>::lowest();
constexpr auto float_min = std::numeric_limits<float>::min();
constexpr auto float_max = std::numeric_limits<float>::max();
constexpr auto double_lowest = std::numeric_limits<double>::lowest();
constexpr auto double_min = std::numeric_limits<double>::min();
constexpr auto double_max = std::numeric_limits<double>::max();
const std::vector<int> ints {
int_min,

View File

@ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() {
c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(uint64_t);
constexpr size_t total_size = seed_size + offset_size;
// The internal state is returned as a CPU byte tensor.
auto state_tensor = at::detail::empty_cpu(
@ -170,9 +170,9 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::xpu::assertNotCapturing(
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(uint64_t);
constexpr size_t total_size = seed_size + offset_size;
at::detail::check_rng_state(new_state);

View File

@ -6,7 +6,7 @@ import os
import subprocess
import sys
import tempfile
from typing import Callable
from collections.abc import Callable
from torch._inductor.utils import fresh_cache

View File

@ -1,7 +1,8 @@
import os
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Optional
import matplotlib.pyplot as plt

View File

@ -1,4 +1,5 @@
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
import torch

View File

@ -1,7 +1,8 @@
import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import Any, Callable, NamedTuple
from collections.abc import Callable
from typing import Any, NamedTuple
import torch
from torch.autograd import functional

View File

@ -1,5 +1,6 @@
from collections import defaultdict
from typing import Callable, Optional, Union
from collections.abc import Callable
from typing import Optional, Union
import torch
from torch import nn, Tensor

View File

@ -1,5 +1,6 @@
import dataclasses
from typing import Callable, Optional
from collections.abc import Callable
from typing import Optional
all_experiments: dict[str, Callable] = {}

View File

@ -9,8 +9,9 @@ import logging
import time
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Optional
from typing import Any, Optional
from tabulate import tabulate
from tqdm import tqdm

View File

@ -7,6 +7,7 @@ from pt import ( # noqa: F401
binary_inplace_test,
binary_test,
bmm_test,
boolean_test,
cat_test,
channel_shuffle_test,
chunk_test,

View File

@ -56,6 +56,9 @@ binary_ops_list = op_bench.op_list(
["sub", torch.sub],
["div", torch.div],
["mul", torch.mul],
["asr", torch.bitwise_right_shift],
["lsl", torch.bitwise_left_shift],
["xor", torch.bitwise_xor],
],
)

View File

@ -0,0 +1,73 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for boolean operators. Supports both Caffe2/PyTorch."""
# Configs for PT all operator
all_long_configs = op_bench.cross_product_configs(
M=[8, 128], N=[32, 64], K=[256, 512], device=["cpu", "cuda"], tags=["long"]
)
all_short_configs = op_bench.config_list(
attr_names=["M", "N", "K"],
attrs=[
[1, 1, 1],
[64, 64, 64],
[64, 64, 128],
],
cross_product_configs={
"device": ["cpu", "cuda"],
},
tags=["short"],
)
class AllBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device):
self.inputs = {
"input_one": torch.randint(0, 2, (M, N, K), device=device, dtype=torch.bool)
}
self.set_module_name("all")
def forward(self, input_one):
return torch.all(input_one)
# The generated test names based on all_short_configs will be in the following pattern:
# all_M1_N1_K1_devicecpu
# all_M1_N1_K1_devicecpu_bwdall
# all_M1_N1_K1_devicecpu_bwd1
# all_M1_N1_K1_devicecpu_bwd2
# ...
# Those names can be used to filter tests.
op_bench.generate_pt_test(all_long_configs + all_short_configs, AllBenchmark)
"""Mircobenchmark for any operator."""
class AnyBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, device):
self.inputs = {
"input_one": torch.randint(0, 2, (M, N), device=device, dtype=torch.bool)
}
self.set_module_name("any")
def forward(self, input_one):
return torch.any(input_one)
any_configs = op_bench.cross_product_configs(
M=[8, 256],
N=[256, 16],
device=["cpu", "cuda"],
tags=["any"],
)
op_bench.generate_pt_test(any_configs, AnyBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -1,7 +1,8 @@
import itertools
from collections.abc import Callable
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, Union
from typing import Union
import numpy as np
from tabulate import tabulate

View File

@ -3,10 +3,11 @@ import csv
import itertools
import random
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, Optional, Union
from typing import Optional, Union
import numpy as np
from tabulate import tabulate

View File

@ -1,8 +1,8 @@
import itertools
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable
from tabulate import tabulate
from tqdm import tqdm

View File

@ -1729,8 +1729,10 @@ def define_buck_targets(
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/backends/backend_interface.cpp",
],
compiler_flags = get_pt_compiler_flags(),
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
}),
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = get_no_as_needed_linker_flag(),
@ -2023,6 +2025,9 @@ def define_buck_targets(
"ovr_config//os:android-x86_64": [
"-mssse3",
],
}) + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
}),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [

View File

@ -3,11 +3,11 @@ from __future__ import annotations
import dis
import inspect
import sys
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import torch
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

View File

@ -5,7 +5,7 @@ Python implementation of function wrapping functionality for functorch.dim.
from __future__ import annotations
import functools
from typing import Any, Callable, Optional
from typing import Any, Optional, TYPE_CHECKING
import torch
from torch.utils._pytree import tree_map
@ -15,6 +15,10 @@ from ._enable_all_layers import EnableAllLayers
from ._tensor_info import TensorInfo
if TYPE_CHECKING:
from collections.abc import Callable
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Handle tensor conversion for torch function integration."""
return tensor

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import functools
from collections.abc import Callable
from types import (
BuiltinMethodType,
FunctionType,
@ -12,7 +13,7 @@ from types import (
MethodDescriptorType,
WrapperDescriptorType,
)
from typing import Any, Callable
from typing import Any
FUNC_TYPES = (

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import functools
from typing import Callable, TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Union
import torch
from functorch.dim import dims # noqa: F401
@ -16,7 +16,7 @@ from ._parsing import (
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
__all__ = ["rearrange"]

View File

@ -23,7 +23,7 @@ project-excludes = [
# ==== below will be enabled directory by directory ====
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/runtime",
"torch/_inductor/codegen",
"torch/_inductor/codegen/triton.py",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"torch/linalg/__init__.py",

View File

@ -1023,7 +1023,7 @@ class DTensorMeshTest(DTensorTestBase):
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
DTensorMeshTest,
skipped_tests=[
# Submeshes are not supported by local tensor mode
# Test asserts must be rewritten for local tensor
"test_from_local_sub_mesh",
"test_default_value_sub_mesh",
"test_redistribute_sub_mesh",

View File

@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from unittest.mock import patch
import torch
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor.placement_types import Replicate
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu
class TestDynamic(DTensorTestBase):
@requires_gpu
@with_comms
@parametrize("fake_tensor_cache_enabled", [False, True])
def test_embedding(self, fake_tensor_cache_enabled):
with patch.object(
torch._dynamo.config, "fake_tensor_cache_enabled", fake_tensor_cache_enabled
):
device_mesh = self.build_device_mesh()
placements = (Replicate(),)
num_embeddings = 202048
embedding_dim = 256
weight = distribute_tensor(
torch.rand(
[num_embeddings, embedding_dim],
dtype=torch.float32,
device=GPU_TYPE,
requires_grad=True,
),
device_mesh,
placements, # [Replicate()],
)
def forward(input_batch_inputs_):
to = weight.to(torch.float32)
emb = torch.nn.functional.embedding(input_batch_inputs_, to)
return emb
arg0 = torch.randint(
low=0, high=100, size=(2, 512), dtype=torch.int64, device=GPU_TYPE
)
arg0 = DTensor.from_local(arg0, device_mesh, placements)
compiled_forward = torch.compile(forward, fullgraph=True, dynamic=True)
_out = compiled_forward(arg0)
instantiate_parametrized_tests(TestDynamic)
if __name__ == "__main__":
run_tests()

View File

@ -30,6 +30,7 @@ from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard, Placement
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
with_comms,
)
@ -647,7 +648,7 @@ class TestViewOps(DTensorTestBase):
@with_comms
def test_squeeze_(self):
mesh_2d = init_device_mesh(self.device_type, (3, 2), mesh_dim_names=("a", "b"))
torch.manual_seed(self.rank)
self.init_manual_seed_for_rank()
x = torch.randn((1, 4), device=self.device_type)
dist_x = DTensor.from_local(x, mesh_2d, [Partial(), Shard(1)])
self._test_op_on_dtensor(
@ -664,5 +665,13 @@ class TestViewOps(DTensorTestBase):
self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
TestViewOpsWithLocalTensor = create_local_tensor_test_class(
TestViewOps,
skipped_tests=[
# Comparing data pointers is not supported for local tensor
"test_dtensor_view_op_uneven",
],
)
if __name__ == "__main__":
run_tests()

View File

@ -7,8 +7,13 @@ from dataclasses import dataclass
import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.testing._internal.common_cuda import SM100OrLater
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests
from torch.testing._internal.common_utils import (
requires_cuda_p2p_access,
run_tests,
skip_but_pass_in_sandcastle_if,
)
# So that tests are written in a device-agnostic way
@ -59,6 +64,10 @@ class CupyAsTensorTest(MultiProcContinuousTest):
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skip_but_pass_in_sandcastle_if(
SM100OrLater,
"Fails if ran in docker environment without privileged access (https://github.com/pytorch/pytorch/issues/165170)",
)
def test_cupy_as_tensor(self) -> None:
"""
Test that torch.as_tensor works for cupy array interface

View File

@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
def test_remap_to_tensor(self):
"""Test the remap_to_tensor method for various scenarios."""
# Test 1: Consecutive ranks, full world - should return logical groups directly
original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
result1 = layout1.remap_to_tensor(original_mesh)
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result1, expected1)
# Test 2: Non-consecutive ranks - should map to actual ranks
original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
layout2 = _Layout((2, 2), (2, 1))
result2 = layout2.remap_to_tensor(original_mesh)
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
self.assertEqual(result5, expected5)
# Test 6: Tensor Cute representation of a 2D mesh
original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
layout6 = _Layout((2, 2), (1, 2)) # column-major style
result6 = layout6.remap_to_tensor(original_mesh)
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)

View File

@ -12,6 +12,7 @@ import torch.distributed._symmetric_memory as symm_mem
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch._inductor.runtime.triton_compat import triton
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem
from torch.testing._internal.common_cuda import SM100OrLater
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -264,6 +265,10 @@ def my_reduce_kernel(
nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation)
@skip_but_pass_in_sandcastle_if(
SM100OrLater,
"Skipping all NVSHMEM Triton tests due to https://github.com/pytorch/pytorch/issues/162897",
)
@instantiate_parametrized_tests
class NVSHMEMTritonTest(MultiProcContinuousTest):
def _init_device(self) -> None:

View File

@ -52,6 +52,9 @@ from torch.testing._internal.common_utils import (
test_contexts = [nullcontext, _test_mode]
# Set environment variable to disable multicast for all tests in this module
os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1"
# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)
@ -549,6 +552,10 @@ class AsyncTPTest(MultiProcContinuousTest):
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
@parametrize("scatter_dim", [0, 1])
@parametrize("rowwise", [True, False])
@skipIf(
SM100OrLater,
"https://github.com/pytorch/pytorch/issues/162940",
)
def test_fused_scaled_matmul_reduce_scatter(
self, scatter_dim: int, rowwise: bool
) -> None:

View File

@ -510,6 +510,7 @@ class TestDynamoTimed(TestCase):
raw = dataclasses.asdict(compilation_events[0])
del raw["feature_usage"]
del raw["ir_count"]
del raw["inductor_provenance"]
del raw["param_numel"]
del raw["param_bytes"]
del raw["param_count"]
@ -694,6 +695,7 @@ class TestDynamoTimed(TestCase):
raw = dataclasses.asdict(compilation_events[1])
del raw["feature_usage"]
del raw["ir_count"]
del raw["inductor_provenance"]
del raw["guard_latency_us"]
del raw["param_numel"]
del raw["param_bytes"]
@ -911,6 +913,27 @@ class TestDynamoTimed(TestCase):
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
self.assertEqual(compilation_events[0].ir_count, second)
@dynamo_config.patch(
{
"log_compilation_metrics": True,
}
)
@inductor_config.patch(
{"trace.enabled": True, "trace.provenance_tracking_level": 1},
)
def test_inductor_provenance(self):
module = torch.nn.Linear(6, 66)
graph_module = torch.fx.symbolic_trace(module)
compilation_events = []
with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event:
torch.compile(graph_module)(torch.randn(6, 6))
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
self.assertEqual(
compilation_events[0].inductor_provenance,
{'{"extern_kernels.addmm:1": []}'},
)
@dynamo_config.patch({"log_compilation_metrics": True})
@inductor_config.patch({"force_disable_caches": True})
def test_dynamic_shape_feature_use(self):

View File

@ -57,7 +57,7 @@ def graph_capture(model, inputs, with_export):
with ExitStack() as stack:
joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
model,
gm,
inputs,
)
return joint_with_descriptors.graph_module

View File

@ -2340,6 +2340,39 @@ class AOTInductorTestsTemplate:
dynamic_shapes=dynamic_shapes,
)
def test_cond_symint_input_disable_one_pass(self):
class M(torch.nn.Module):
def forward(self, x, y, z):
a = y.shape[0]
b = z.shape[0]
def true_fn(x):
return x + a
def false_fn(x):
return x + b * z
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
input1 = (
torch.ones(3, 3, device=self.device),
torch.ones(5, device=self.device),
torch.ones(3, 3, device=self.device),
)
input2 = (
torch.ones(10, 3, device=self.device),
torch.ones(6, device=self.device),
torch.ones(10, 3, device=self.device),
)
inputs = (input1, input2)
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
with torch._inductor.config.patch({"triton.autotune_at_compile_time": False}):
self.check_model_with_multiple_inputs(
M(),
inputs,
dynamic_shapes=dynamic_shapes,
)
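For intuition about what this test exercises: with a plain Python boolean predicate, torch.cond in eager mode should simply dispatch to the matching branch (possibly warning about the constant predicate), so the module's branching logic can be sanity-checked without AOTI. A short illustrative sketch mirroring M.forward:

import torch

def branchy(x, y, z):
    a, b = y.shape[0], z.shape[0]
    true_fn = lambda x: x + a
    false_fn = lambda x: x + b * z
    return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))

print(branchy(torch.ones(3, 3), torch.ones(5), torch.ones(3, 3)))    # takes false_fn
print(branchy(torch.ones(10, 3), torch.ones(6), torch.ones(10, 3)))  # takes true_fn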
def test_while_loop_simple(self):
inputs = (
torch.randn((10, 20), device=self.device),

View File

@ -0,0 +1,371 @@
# Owner(s): ["module: inductor"]
import os
import platform
import tempfile
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
import torch
import torch._inductor.config
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu
@dataclass
class ModelTestConfig:
"""Configuration for a model test case."""
name: str
model_class: type
example_inputs: tuple[torch.Tensor, ...]
dynamic_shapes: Optional[dict[str, Any]] = None
inductor_configs: Optional[dict[str, Any]] = None
rtol: float = 1e-4
atol: float = 1e-4
class WindowsCrossCompilationTestFramework:
"""
Framework for testing cross-compilation from Linux to Windows.
Provides reusable logic for creating compile and load test methods.
"""
_base_path: Optional[Path] = None
_win_torch_libs_path: Optional[str] = None
@classmethod
def base_path(cls) -> Path:
"""Get or create the base path for package files."""
if cls._base_path is None:
cls._base_path = Path(tempfile.mkdtemp(prefix="aoti_cross_compile_"))
return cls._base_path
@classmethod
def set_base_path(cls, path: Optional[Path | str] = None) -> None:
"""Set the base path for package files."""
cls._base_path = Path(path) if path else None
@classmethod
def set_win_torch_libs_path(cls, path: Optional[str] = None) -> None:
"""Set the path for Windows torch libs."""
cls._win_torch_libs_path = path
@classmethod
def get_package_path(cls, model_name: str) -> str:
"""Get the path for a model's .pt2 package file."""
package_dir = cls.base_path()
package_dir.mkdir(parents=True, exist_ok=True)
return str(package_dir / f"{model_name}_windows.pt2")
@classmethod
def get_win_torch_libs_path(cls) -> str:
"""Get the path for Windows torch libs."""
if cls._win_torch_libs_path is None:
raise RuntimeError("Windows torch libs path not set")
return str(cls._win_torch_libs_path)
@classmethod
def create_compile_test(cls, config: ModelTestConfig):
"""Create a compile test method for a model configuration."""
def compile_test(self):
if platform.system() == "Windows":
raise unittest.SkipTest(
"This test should run on Linux for cross-compilation"
)
self.assertTrue("WINDOWS_CUDA_HOME" in os.environ)
with torch.no_grad():
# Windows cross-compilation is only used for GPU.
# AOTI for CPU should be able to work as native compilation on Windows.
device = GPU_TYPE
model = config.model_class().to(device=device)
example_inputs = config.example_inputs
# Inputs should already be on GPU_TYPE but ensure they are
example_inputs = tuple(inp.to(device) for inp in example_inputs)
# Export the model
exported = torch.export.export(
model, example_inputs, dynamic_shapes=config.dynamic_shapes
)
# Prepare inductor configs
inductor_configs = {
"aot_inductor.cross_target_platform": "windows",
"aot_inductor.precompile_headers": False,
"aot_inductor.package_constants_on_disk_format": "binary_blob",
"aot_inductor.package_constants_in_so": False,
"aot_inductor.aoti_shim_library_path": cls.get_win_torch_libs_path(),
}
if config.inductor_configs:
inductor_configs.update(config.inductor_configs)
# Compile and package directly to the expected location
package_path = cls.get_package_path(config.name)
torch._inductor.aoti_compile_and_package(
exported,
package_path=package_path,
inductor_configs=inductor_configs,
)
self.assertTrue(
os.path.exists(package_path),
f"Package file should exist at {package_path}",
)
return compile_test
@classmethod
def create_load_test(cls, config: ModelTestConfig):
"""Create a load test method for a model configuration."""
def load_test(self):
if platform.system() != "Windows":
raise unittest.SkipTest("This test should run on Windows")
if not HAS_GPU:
raise unittest.SkipTest("Test requires GPU")
package_path = cls.get_package_path(config.name)
if not os.path.exists(package_path):
raise unittest.SkipTest(
f"Package file not found at {package_path}. "
f"Run test_{config.name}_compile first."
)
with torch.no_grad():
# Windows cross-compilation is only used for GPU.
# AOTI for CPU should be able to work as native compilation on Windows.
device = GPU_TYPE
# Create original model for comparison
original_model = config.model_class().to(device=device)
example_inputs = config.example_inputs
# Inputs should already be on GPU_TYPE but ensure they are
example_inputs = tuple(inp.to(device) for inp in example_inputs)
# Load the compiled package
loaded_model = torch._inductor.aoti_load_package(package_path)
# Test with the same inputs
original_output = original_model(*example_inputs)
loaded_output = loaded_model(*example_inputs)
# Compare outputs
torch.testing.assert_close(
original_output, loaded_output, rtol=config.rtol, atol=config.atol
)
return load_test
def auto_generate_tests(test_class):
"""
Class decorator to automatically generate compile/load test methods
from _define_* methods that return ModelTestConfig.
"""
# Find all _define_* methods that return ModelTestConfig
define_methods = {}
for name in dir(test_class):
if name.startswith("_define_") and callable(getattr(test_class, name)):
method = getattr(test_class, name)
# Try to call the method to see if it returns ModelTestConfig
try:
# Create a temporary instance to call the method
temp_instance = test_class.__new__(test_class)
result = method(temp_instance)
if isinstance(result, ModelTestConfig):
define_methods[name] = result
except Exception:
# If method fails, skip it
pass
# Generate compile/load methods for each discovered definition
for define_name, config in define_methods.items():
model_name = define_name[8:] # Remove '_define_' prefix
# Create compile test method
compile_method_name = f"test_{model_name}_compile"
compile_method = WindowsCrossCompilationTestFramework.create_compile_test(
config
)
compile_method.__name__ = compile_method_name
compile_method.__doc__ = f"Step 1: Cross-compile {model_name} model on Linux"
compile_method = requires_gpu()(compile_method)
setattr(test_class, compile_method_name, compile_method)
# Create load test method
load_method_name = f"test_{model_name}_load"
load_method = WindowsCrossCompilationTestFramework.create_load_test(config)
load_method.__name__ = load_method_name
load_method.__doc__ = f"Step 2: Load and test {model_name} model on Windows"
load_method = requires_gpu()(load_method)
setattr(test_class, load_method_name, load_method)
return test_class
@auto_generate_tests
class TestAOTInductorWindowsCrossCompilation(TestCase):
"""
Test class for AOT Inductor Windows cross-compilation.
Define test methods that return ModelTestConfig, and the decorator
will auto-generate compile/load test methods.
"""
def _define_simple(self):
"""Define the Simple model and its test configuration."""
class Simple(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(16, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
return ModelTestConfig(
name="simple",
model_class=Simple,
example_inputs=(torch.randn(8, 10, device=GPU_TYPE),),
dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=1024)}},
)
def _define_simple_cnn(self):
"""Define the SimpleCNN model and its test configuration."""
class SimpleCNN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 16, 3)
self.relu = torch.nn.ReLU()
self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
self.fc = torch.nn.Linear(16, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool(x)
x = x.flatten(1)
x = self.fc(x)
return x
return ModelTestConfig(
name="simple_cnn",
model_class=SimpleCNN,
example_inputs=(torch.randn(2, 3, 32, 32, device=GPU_TYPE),),
dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=16)}},
rtol=1e-3,
atol=1e-3,
)
def _define_transformer(self):
"""Define the SimpleTransformer model and its test configuration."""
class SimpleTransformer(torch.nn.Module):
def __init__(self):
super().__init__()
self.embedding = torch.nn.Linear(128, 256)
self.attention = torch.nn.MultiheadAttention(256, 8, batch_first=True)
self.norm1 = torch.nn.LayerNorm(256)
self.ffn = torch.nn.Sequential(
torch.nn.Linear(256, 1024),
torch.nn.ReLU(),
torch.nn.Linear(1024, 256),
)
self.norm2 = torch.nn.LayerNorm(256)
self.output = torch.nn.Linear(256, 10)
def forward(self, x):
# x shape: (batch, seq_len, input_dim)
x = self.embedding(x)
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
x = x.mean(dim=1) # Global average pooling
x = self.output(x)
return x
return ModelTestConfig(
name="transformer",
model_class=SimpleTransformer,
example_inputs=(torch.randn(4, 16, 128, device=GPU_TYPE),),
dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=32)}},
rtol=1e-3,
atol=1e-3,
)
if __name__ == "__main__":
import sys
from torch._inductor.test_case import run_tests
# Check for --package-dir argument and remove it before unittest sees it
package_dir = None
win_torch_lib_dir = None
filtered_argv = []
i = 0
while i < len(sys.argv):
if sys.argv[i] == "--package-dir":
if i + 1 < len(sys.argv):
package_dir = sys.argv[i + 1]
i += 2 # Skip both --package-dir and its value
else:
print("Error: --package-dir requires a valid directory path")
sys.exit(1)
elif sys.argv[i].startswith("--package-dir="):
package_dir = sys.argv[i].split("=", 1)[1]
i += 1
elif sys.argv[i] == "--win-torch-lib-dir":
if i + 1 < len(sys.argv):
win_torch_lib_dir = sys.argv[i + 1]
i += 2 # Skip both --win-torch-lib-dir and its value
else:
print("Error: --win-torch-lib-dir requires a valid directory path")
sys.exit(1)
elif sys.argv[i].startswith("--win-torch-lib-dir="):
win_torch_lib_dir = sys.argv[i].split("=", 1)[1]
i += 1
else:
filtered_argv.append(sys.argv[i])
i += 1
# Validate and set the base path for package storage
if package_dir:
try:
package_path = Path(package_dir)
package_path.mkdir(parents=True, exist_ok=True)
# Test write access
test_file = package_path / ".test_write"
test_file.touch()
test_file.unlink()
WindowsCrossCompilationTestFramework.set_base_path(package_path)
except Exception:
print("Error: --package-dir requires a valid directory path")
sys.exit(1)
# Set Windows torch libs path if provided (only needed for compile tests)
if win_torch_lib_dir:
WindowsCrossCompilationTestFramework.set_win_torch_libs_path(win_torch_lib_dir)
# Update sys.argv to remove our custom arguments
sys.argv = filtered_argv
if HAS_GPU:
run_tests(needs="filelock")
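Stripped of the framework plumbing, the two-step flow this test file drives looks roughly like the sketch below. This is a simplified assumption, not the test's code: the inductor config keys are the ones set in create_compile_test, the package path and Windows shim-library directory are placeholders, and the mingw/WINDOWS_CUDA_HOME prerequisites checked by the compile test are taken as given.

import torch
import torch._inductor

class Toy(torch.nn.Module):
    # Parameter-free so a freshly constructed copy matches the packaged one.
    def forward(self, x):
        return torch.relu(x @ x.T)

def cross_compile_on_linux(package_path, win_torch_libs):
    model, example = Toy().cuda(), (torch.randn(8, 8, device="cuda"),)
    exported = torch.export.export(model, example)
    torch._inductor.aoti_compile_and_package(
        exported,
        package_path=package_path,  # placeholder, e.g. /tmp/toy_windows.pt2
        inductor_configs={
            "aot_inductor.cross_target_platform": "windows",
            "aot_inductor.package_constants_in_so": False,
            "aot_inductor.package_constants_on_disk_format": "binary_blob",
            "aot_inductor.aoti_shim_library_path": win_torch_libs,  # placeholder dir
        },
    )

def load_on_windows(package_path):
    example = (torch.randn(8, 8, device="cuda"),)
    compiled = torch._inductor.aoti_load_package(package_path)
    torch.testing.assert_close(Toy().cuda()(*example), compiled(*example))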

View File

@ -623,7 +623,8 @@ class TestFP8Lowering(TestCase):
bias,
)
FileCheck().check("SCALING_ROWWISE : tl.constexpr = False").run(code[0])
FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 0").run(code[0])
FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 0").run(code[0])
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
# depending on the kernel config (BLOCK_M size, etc) selected during Inductor
@ -768,7 +769,8 @@ class TestFP8Lowering(TestCase):
bias,
)
FileCheck().check("SCALING_ROWWISE : tl.constexpr = True").run(code[0])
FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 1").run(code[0])
FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 1").run(code[0])
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

View File

@ -8423,22 +8423,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertEqual(fn(x[0:]), x[16:][:16])
self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])
def test_index_float_zero(self):
def fn(arg0, arg1, arg2):
t1 = torch.tanh(arg0)
t2 = t1.clone()
t2.fill_(arg1.item())
t3 = torch.clamp(t2, 0, arg2.size(0) - 1).to(torch.long)
return torch.nn.functional.embedding(t3, arg2)
arg0 = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device)
arg1 = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device)
arg2 = torch.rand([256, 88], dtype=torch.float16, device=self.device)
cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
self.assertEqual(fn(arg0, arg1, arg2), cfn(arg0, arg1, arg2))
# from GPT2ForSequenceClassification
@skip_if_gpu_halide
def test_index_tensor(self):

View File

@ -592,7 +592,6 @@ def graph_executor(
proto = onnxscript_function.to_function_proto()
ir_function = ir.serde.deserialize_function(proto)
onnx_model.functions[identifier] = ir_function
_ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
_ir_passes.add_opset_imports(onnx_model)
# Make sure the model is valid
model_proto = ir.to_proto(onnx_model)

View File

@ -384,143 +384,6 @@ class TestTorchAutocast(TestCase):
with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
torch.autocast(device_type=dev)
@skipIfTorchDynamo()
def test_autocast_nograd_caching_issue_158232(self):
"""
Regression test for issue #158232: autocast + no_grad incompatibility
When torch.no_grad() is nested inside torch.autocast(), the autocast cache
must not cache tensors created in the no_grad context, because they lack
gradient tracking. If cached, subsequent operations in gradient-enabled mode
would incorrectly use the no-gradient cached version.
Before fix: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
After fix: Should work correctly
"""
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
# First forward pass in no_grad context (e.g., shape inference)
with torch.no_grad():
out1 = model(inp)
self.assertFalse(
out1.requires_grad, "Output in no_grad should not require grad"
)
# Second forward pass with gradients enabled (e.g., training)
out2 = model(inp)
self.assertTrue(
out2.requires_grad,
"Output should require gradients after exiting no_grad",
)
self.assertIsNotNone(
out2.grad_fn, "Output should have grad_fn after exiting no_grad"
)
# Backward pass should work
loss = out2.mean()
loss.backward()
# Verify gradients were computed
self.assertIsNotNone(model.weight.grad)
self.assertIsNotNone(model.bias.grad)
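Before the fix, the stale cache could presumably be sidestepped by flushing it manually between the two passes; the sketch below shows that assumed workaround (it is not part of this diff) using torch.clear_autocast_cache().

import torch

model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
with torch.autocast("cpu", dtype=torch.bfloat16):
    with torch.no_grad():
        model(inp)                 # e.g. a shape-inference pass
    torch.clear_autocast_cache()   # drop weight casts created without grad tracking
    out = model(inp)               # gradient-enabled pass re-casts the weights
out.mean().backward()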
@skipIfTorchDynamo()
def test_autocast_inference_mode_interaction(self):
"""
Test that autocast works correctly with torch.inference_mode()
InferenceMode is a stricter version of no_grad that provides additional
performance optimizations. Verify it doesn't break with autocast.
"""
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
# Test 1: inference_mode inside autocast
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
with torch.inference_mode():
out1 = model(inp)
self.assertFalse(out1.requires_grad)
self.assertEqual(out1.dtype, torch.bfloat16)
# After exiting inference_mode, gradients should work
out2 = model(inp)
self.assertTrue(out2.requires_grad)
out2.mean().backward()
# Test 2: autocast inside inference_mode
with torch.inference_mode():
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
out = model(inp)
self.assertFalse(out.requires_grad)
self.assertEqual(out.dtype, torch.bfloat16)
def test_autocast_caching_still_works_with_gradients(self):
"""
Verify that autocast caching still functions correctly when gradients ARE enabled.
This test ensures the fix for #158232 didn't break normal caching behavior.
We can't directly observe cache hits, but we verify that repeated operations
with gradients enabled work correctly.
"""
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
# Multiple forward passes with gradients enabled
out1 = model(inp)
out2 = model(inp)
out3 = model(inp)
# All should have gradients
self.assertTrue(out1.requires_grad)
self.assertTrue(out2.requires_grad)
self.assertTrue(out3.requires_grad)
# All should have grad_fn
self.assertIsNotNone(out1.grad_fn)
self.assertIsNotNone(out2.grad_fn)
self.assertIsNotNone(out3.grad_fn)
# Backward should work on all
out1.mean().backward(retain_graph=True)
out2.mean().backward(retain_graph=True)
out3.mean().backward()
@skipIfTorchDynamo()
def test_autocast_mixed_grad_contexts(self):
"""
Test complex nesting of gradient contexts within autocast.
This ensures the gradient mode check works correctly across
multiple transitions between gradient-enabled and disabled states.
"""
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
# Pass 1: no_grad
with torch.no_grad():
out1 = model(inp)
self.assertFalse(out1.requires_grad)
# Pass 2: gradients enabled
out2 = model(inp)
self.assertTrue(out2.requires_grad)
# Pass 3: no_grad again
with torch.no_grad():
out3 = model(inp)
self.assertFalse(out3.requires_grad)
# Pass 4: gradients enabled again
out4 = model(inp)
self.assertTrue(out4.requires_grad)
# Backward on gradient-enabled outputs
(out2.mean() + out4.mean()).backward()
if __name__ == "__main__":
run_tests()

View File

@ -56,14 +56,15 @@ if TEST_CUDA:
# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
def xfailIfSM100OrLaterAndCondition(condition_fn):
def xfailIfSM100OrLaterNonRTXAndCondition(condition_fn):
"""
Conditionally xfail tests on SM100+ based on a condition function.
Conditionally xfail tests on SM100+ datacenter SKUs based on a condition function.
The condition function receives the test parameters dict and returns True to xfail.
"""
computeCapabilityCheck = SM100OrLater and torch.cuda.get_device_capability()[0] != 12
return decorateIf(
unittest.expectedFailure,
lambda params: SM100OrLater and condition_fn(params)
lambda params: computeCapabilityCheck and condition_fn(params)
)
@ -163,7 +164,7 @@ class TestMatmulCuda(InductorTestCase):
self.cublas_addmm(size, dtype, False)
@onlyCUDA
@xfailIfSM100OrLaterAndCondition(lambda params: params.get('dtype') == torch.bfloat16 and params.get('size') == 10000)
@xfailIfSM100OrLaterNonRTXAndCondition(lambda params: params.get('dtype') == torch.bfloat16 and params.get('size') == 10000)
# imported 'tol' as 'xtol' to avoid aliasing in code above
@toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})

View File

@ -1,5 +1,5 @@
import functools
from typing import Callable
from collections.abc import Callable
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
from torchgen.context import native_function_manager
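This and several later hunks apply the same mechanical migration: Callable moves from typing to collections.abc, and annotation-only uses are deferred behind TYPE_CHECKING. A minimal illustration of the pattern, with made-up names:

from __future__ import annotations

from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Callable

def apply_hook(value: Any, hook: Callable[[Any], Any]) -> Any:
    # The annotation is never evaluated at runtime, so the deferred import suffices.
    return hook(value)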

View File

@ -36,7 +36,7 @@ from __future__ import annotations
import itertools
import re
from collections import defaultdict
from typing import Callable, TYPE_CHECKING
from typing import TYPE_CHECKING
import yaml
@ -77,7 +77,7 @@ from .gen_trace_type import should_trace
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
#

View File

@ -29,7 +29,7 @@
from __future__ import annotations
import re
from typing import Callable, TYPE_CHECKING
from typing import TYPE_CHECKING
from torchgen.api import cpp
from torchgen.api.autograd import (
@ -106,7 +106,7 @@ from .gen_trace_type import (
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
# We don't set or modify grad_fn on these methods. Generally, they return

View File

@ -196,7 +196,7 @@ class FuzzTemplate:
class DefaultFuzzTemplate(FuzzTemplate):
def __init__(self):
from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck
from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck
super().__init__(
supported_ops=[
@ -236,7 +236,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
# Regularization
"torch.nn.functional.dropout",
],
check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
check=EagerVsFullGraphDynamicCompileCheck(),
)
def spec_distribution(self):

View File

@ -241,7 +241,7 @@ if __name__ == "__main__":
import argparse
try:
from multi_process_fuzzer import run_multi_process_fuzzer
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
except ImportError:
# If importing as a module fails, import from the same directory
import os
@ -249,7 +249,7 @@ if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
from multi_process_fuzzer import run_multi_process_fuzzer
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
# Set up command-line argument parsing
parser = argparse.ArgumentParser(
@ -296,6 +296,11 @@ if __name__ == "__main__":
action="store_true",
help="Print detailed output for all runs (not just failures)",
)
parser.add_argument(
"--stop-at-first-failure",
action="store_true",
help="Pick a random seed and keep iterating until finding a failure (exits with non-zero code)",
)
# Legacy arguments
parser.add_argument(
@ -337,6 +342,30 @@ if __name__ == "__main__":
supported_ops=parsed_supported_ops,
op_weights=(parsed_weights if parsed_weights else None),
)
elif args.stop_at_first_failure:
# Stop-at-first-failure mode
# Default number of processes
if args.processes is None:
cpu_count = mp.cpu_count()
args.processes = max(1, min(16, int(cpu_count * 0.75)))
if args.processes < 1:
print("❌ Error: Number of processes must be at least 1")
sys.exit(1)
try:
run_until_failure(
num_processes=args.processes,
verbose=args.verbose,
template=args.template,
supported_ops=args.supported_ops,
)
except Exception as e:
print(f"❌ Unexpected error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
elif args.start is not None or args.count is not None:
# Multi-process fuzzing mode
if args.start is None:

View File

@ -66,6 +66,12 @@ IGNORE_PATTERNS: list[re.Pattern] = [
re.compile(
r"torch\._inductor\.exc\.InductorError: CppCompileError: C\+\+ compile error"
), # https://github.com/pytorch/pytorch/issues/164686
re.compile(
r"\.item\(\) # dtype="
), # https://github.com/pytorch/pytorch/issues/164725
re.compile(
r"dimensionality of sizes \(0\) must match dimensionality of strides \(1\)"
), # https://github.com/pytorch/pytorch/issues/164814
# Add more patterns here as needed, e.g.:
# re.compile(r"Some other error message"),
]
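These regexes feed the ignored_pattern_idx bookkeeping used by the runners further down. A plausible sketch of how such a list is consulted (the actual helper in multi_process_fuzzer may differ):

import re

IGNORE_PATTERNS = [
    re.compile(r"dimensionality of sizes \(0\) must match dimensionality of strides \(1\)"),
]

def first_ignored_pattern(output: str) -> int:
    # Index of the first known-issue pattern found in the fuzzer output,
    # or -1 when the failure should be treated as real.
    for idx, pattern in enumerate(IGNORE_PATTERNS):
        if pattern.search(output):
            return idx
    return -1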
@ -516,3 +522,143 @@ def _print_operation_distribution(results: list[FuzzerResult]) -> None:
persist_print(
"\n📊 No operation statistics collected (no successful runs with stats)"
)
def run_until_failure(
num_processes: Optional[int] = None,
verbose: bool = False,
template: str = "default",
supported_ops: Optional[str] = None,
) -> None:
"""
Run the multi-process fuzzer with a random starting seed, iterating until a failure is found.
Args:
num_processes: Number of worker processes to use
verbose: Whether to print detailed output
template: The template to use for code generation
supported_ops: Comma-separated ops string with optional weights
Returns:
None; the process exits with a non-zero status code once a non-ignored failure is found.
"""
import random
# Pick a random seed to start from
initial_seed = random.randint(0, 2**31 - 1)
persist_print(
f"🎲 Starting continuous fuzzing with random initial seed: {initial_seed}"
)
persist_print(f"🚀 Using {num_processes} processes")
persist_print(
f"🔧 Command template: python fuzzer.py --seed {{seed}} --template {template}"
)
persist_print("🎯 Running until first failure is found...")
persist_print("=" * 60)
start_time = time.time()
current_seed = initial_seed
total_successful = 0
total_ignored = 0
batch_size = 100 # Process seeds in batches of 100
try:
while True:
# Process a batch of seeds
seeds = list(range(current_seed, current_seed + batch_size))
with mp.Pool(processes=num_processes) as pool:
future_results = []
for seed in seeds:
future = pool.apply_async(
run_fuzzer_with_seed, (seed, template, supported_ops)
)
future_results.append((seed, future))
# Set up progress bar for this batch
if HAS_TQDM:
from tqdm import tqdm
pbar = tqdm(
total=len(seeds),
desc=f"Batch starting at seed {current_seed}",
file=sys.stdout,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}] ✅/🚫={postfix}",
dynamic_ncols=True,
)
pbar.set_postfix_str(f"{total_successful}/{total_ignored}")
def write_func(msg):
pbar.write(msg)
else:
pbar = None
# Collect results as they complete
for seed, future in future_results:
result: FuzzerResult = future.get()
if result.ignored_pattern_idx != -1:
total_ignored += 1
if result.success:
total_successful += 1
elif result.ignored_pattern_idx == -1:
# Found a failure that is not ignored!
if HAS_TQDM and pbar:
pbar.close()
elapsed = time.time() - start_time
persist_print("\n" + "=" * 60)
persist_print("🎯 FAILURE FOUND!")
persist_print("=" * 60)
persist_print(f"❌ Failing seed: {result.seed}")
persist_print(
f"⏱️ Duration for this seed: {result.duration:.2f}s"
)
persist_print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
persist_print(f"✅ Successful seeds tested: {total_successful}")
persist_print(f"🚫 Ignored seeds: {total_ignored}")
persist_print(
f"📊 Total seeds tested: {total_successful + total_ignored + 1}"
)
persist_print("\n💥 Failure output:")
persist_print("-" * 60)
print_output_lines(result.output, persist_print)
persist_print("-" * 60)
persist_print(
f"\n🔄 Reproduce with: python fuzzer.py --seed {result.seed} --template {template}"
)
# Exit with non-zero code
sys.exit(1)
# Update progress bar
if HAS_TQDM and pbar:
pbar.set_postfix_str(f"{total_successful}/{total_ignored}")
pbar.update(1)
elif verbose:
status_emoji = "" if result.success else "🚫"
persist_print(f"Seed {result.seed}: {status_emoji}")
# Close progress bar for this batch
if HAS_TQDM and pbar:
pbar.close()
# Move to next batch
current_seed += batch_size
except KeyboardInterrupt:
persist_print("\n🛑 Interrupted by user (Ctrl+C)")
elapsed = time.time() - start_time
persist_print("=" * 60)
persist_print("📈 SUMMARY (interrupted)")
persist_print("=" * 60)
persist_print(f"⏱️ Total time: {elapsed:.2f}s")
persist_print(f"✅ Successful seeds: {total_successful}")
persist_print(f"🚫 Ignored seeds: {total_ignored}")
persist_print(f"📊 Total seeds tested: {total_successful + total_ignored}")
persist_print(
f"⚡ Throughput: {((total_successful + total_ignored) / (elapsed / 3600)):.2f} seeds/hr"
)
sys.exit(130)
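Besides the --stop-at-first-failure flag wired up in fuzzer.py, the new entry point can presumably be driven directly; a usage sketch assuming multi_process_fuzzer is importable from the working directory:

from multi_process_fuzzer import run_until_failure

if __name__ == "__main__":
    # Spawns batches of 100 seeds across 8 workers and exits non-zero once a
    # non-ignored failure shows up (or with code 130 on Ctrl+C).
    run_until_failure(num_processes=8, verbose=True, template="default")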

View File

@ -5,7 +5,8 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Optional
class FlightRecorderLogger:

View File

@ -4,12 +4,16 @@ from __future__ import annotations
import json
import os
from typing import Any, Callable, cast
from typing import Any, cast, TYPE_CHECKING
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen
if TYPE_CHECKING:
from collections.abc import Callable
def gh_fetch_url_and_headers(
url: str,
*,

View File

@ -5,7 +5,7 @@ import json
import sys
from functools import cached_property
from pathlib import Path
from typing import Any, Callable, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
_FILE = Path(__file__).absolute()
@ -18,7 +18,7 @@ else:
import _linter
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
GRANDFATHER_LIST = _FILE.parent / "docstring_linter-grandfather.json"

View File

@ -22,11 +22,15 @@ import os
import re
from enum import Enum
from pathlib import Path
from typing import Any, Callable, NamedTuple, Optional
from typing import Any, NamedTuple, Optional, TYPE_CHECKING
from yaml import load
if TYPE_CHECKING:
from collections.abc import Callable
# Safely load fast C Yaml loader/dumper if they are available
try:
from yaml import CSafeLoader as Loader

View File

@ -65,10 +65,11 @@ import textwrap
import time
import uuid
from ast import literal_eval
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from platform import system as platform_system
from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypeVar
from typing import Any, cast, NamedTuple, TYPE_CHECKING, TypeVar
if TYPE_CHECKING:

View File

@ -7,10 +7,14 @@ import json
import os
import shutil
from pathlib import Path
from typing import Any, Callable, cast
from typing import Any, cast, TYPE_CHECKING
from urllib.request import urlopen
if TYPE_CHECKING:
from collections.abc import Callable
REPO_ROOT = Path(__file__).resolve().parents[2]

View File

@ -6,13 +6,17 @@ import json
import os
import time
import urllib.parse
from typing import Any, Callable, cast
from typing import Any, cast, TYPE_CHECKING
from urllib.error import HTTPError
from urllib.request import Request, urlopen
from tools.stats.upload_stats_lib import upload_to_s3
if TYPE_CHECKING:
from collections.abc import Callable
FILTER_OUT_USERS = {
"pytorchmergebot",
"facebook-github-bot",

View File

@ -9,12 +9,16 @@ import time
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, cast, Optional
from typing import Any, cast, Optional, TYPE_CHECKING
import boto3 # type: ignore[import]
import requests
if TYPE_CHECKING:
from collections.abc import Callable
PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"

View File

@ -107,6 +107,7 @@ TESTS = discover_tests(
"lazy/test_meta_kernel",
"lazy/test_extract_compiled_graph",
"test/inductor/test_aot_inductor_utils",
"inductor/test_aoti_cross_compile_windows",
"onnx/test_onnxscript_no_runtime",
"onnx/test_pytorch_onnx_onnxruntime_cuda",
"onnx/test_models",

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable
from typing import Any, TYPE_CHECKING
from warnings import warn
from tools.testing.target_determination.heuristics.interface import (
@ -17,6 +17,10 @@ from tools.testing.target_determination.heuristics.utils import (
from tools.testing.test_run import TestRun
if TYPE_CHECKING:
from collections.abc import Callable
REPO_ROOT = Path(__file__).parents[3]
keyword_synonyms: dict[str, list[str]] = {

View File

@ -4,7 +4,7 @@ import math
import os
import subprocess
from pathlib import Path
from typing import Callable, TYPE_CHECKING
from typing import TYPE_CHECKING
from tools.stats.import_test_stats import get_disabled_tests
from tools.testing.test_run import ShardedTest, TestRun
@ -19,7 +19,7 @@ except ImportError:
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
REPO_ROOT = Path(__file__).resolve().parents[2]

View File

@ -1442,6 +1442,7 @@ _has_cuda: _bool
_has_magma: _bool
_has_xpu: _bool
_has_mkldnn: _bool
_has_mkldnn_acl: _bool
_has_cudnn: _bool
_has_cusparselt: _bool
has_spectral: _bool

View File

@ -1376,6 +1376,7 @@ class CompilationMetrics:
recompile_user_contexts: Optional[set[str]] = None
inline_inbuilt_nn_modules_candidate: Optional[bool] = False
pytorch_version: Optional[str] = None
inductor_provenance: Optional[str] = None
@classmethod
def create(cls, metrics: dict[str, Any]) -> CompilationMetrics:

View File

@ -42,7 +42,12 @@ import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.exc import SkipFrame
from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
from torch._dynamo.utils import (
CompileEventLogger,
counters,
dynamo_timed,
get_metrics_context,
)
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.common import (
custom_backend_codegen_configs,
@ -339,7 +344,7 @@ def sha256_hash(data: bytes) -> str:
return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()
def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
def code_hash(code: str | bytes, extra: str | bytes = "") -> str:
hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
if extra:
extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
@ -361,9 +366,7 @@ def get_path(
return basename, subdir, path
def get_hash(
content: Union[str, bytes], extra: str = "", hash_type: str = "code"
) -> str:
def get_hash(content: str | bytes, extra: str = "", hash_type: str = "code") -> str:
if hash_type in {"amdgcn", "code", "ptx", "spv"}:
return code_hash(content, extra)
if hash_type in {"cubin", "hsaco", "spv"}:
@ -409,7 +412,7 @@ class WritableTempFile:
def write(
content: Union[str, bytes],
content: str | bytes,
extension: str,
extra: str = "",
hash_type: str = "code",
@ -436,7 +439,7 @@ def write_text(text: str) -> str:
def write_atomic(
path_: str,
content: Union[str, bytes],
content: str | bytes,
make_dirs: bool = False,
encode_utf_8: bool = False,
) -> None:
@ -547,7 +550,7 @@ class FxGraphCachePickler(pickle.Pickler):
def _reduce_tensor(
self, t: Tensor
) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
) -> tuple[Callable[[T], T], tuple[TensorMetadata | TensorMetadataAndValues]]:
"""
Custom reducer to pickle Tensors. If we see tensors, we know they're constants
stored as attributes on the GraphModule.
@ -943,7 +946,7 @@ class FxGraphHashDetails:
raise AssertionError(f"unknown config type: {str(type(custom_pass))}")
def _get_custom_pass_detail(
self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
self, custom_pass: CustomGraphPassType | CustomGraphModulePass
) -> Any | None:
if not custom_pass:
return None
@ -1058,7 +1061,7 @@ class GuardedCache(Generic[T]):
key: str,
local: bool,
remote_cache: RemoteCache[JsonDataTy] | None,
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool],
hints: list[int],
) -> tuple[T | None, bytes | None, dict[str, str]]:
"""
@ -1283,6 +1286,10 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
},
payload_fn=lambda: graph.inductor_provenance_stack_traces_str,
)
if get_metrics_context().in_progress():
get_metrics_context().add_to_set(
"inductor_provenance", graph.inductor_provenance_stack_traces_str
)
return graph, cache_info
@staticmethod
@ -1292,7 +1299,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
local: bool,
remote_cache: RemoteCache[JsonDataTy] | None,
constants: CompiledFxGraphConstants,
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool]
| None = None,
) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
"""
@ -1543,7 +1550,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
remote_cache: RemoteCache[JsonDataTy] | None,
is_backward: bool,
constants: CompiledFxGraphConstants,
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool]
| None = None,
) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
"""
@ -1723,12 +1730,12 @@ class AotCodeCompiler:
*,
device_type: str,
additional_files: list[str],
) -> Union[list[Union[str, Weights]], str]:
) -> list[Union[str, Weights]] | str:
"""
Returns the .so path, or returns a list of files that were generated if
config.aot_inductor.package=True.
"""
generated_files: list[Union[str, Weights]] = additional_files # type: ignore[assignment]
generated_files: list[str | Weights] = additional_files # type: ignore[assignment]
_set_gpu_runtime_env() # cpp_extension consults the env
@ -2342,7 +2349,7 @@ end
f.write(json.dumps(qual_name_to_id))
generated_files.append(constants_config_json)
gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = (
gpu_codecache: ROCmCodeCache | CUDACodeCache = (
ROCmCodeCache() if torch.version.hip else CUDACodeCache()
)
gpu_kernels_o = gpu_codecache.aot_kernels_o.copy()
@ -2555,7 +2562,7 @@ end
_libgomp: CDLL | None = None
def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, None]:
def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p | None:
# This function will be called from generated cpp wrapper code in the JIT mode.
# Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them.
def convert_arg(arg: Any) -> Any:
@ -2698,16 +2705,16 @@ class CppCodeCache:
"""Compiles and caches C++ libraries. Users of this class supply the source code to
be compiled, while compilation flags are set by CppBuilder."""
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
cache_clear = staticmethod(cache.clear)
cpp_compile_command_flags: dict[str, Any] = {}
@staticmethod
def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
def _load_library_inner(path: str, key: str) -> CDLL | ModuleType:
return cdll.LoadLibrary(path)
@classmethod
def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
def _load_library(cls, path: str, key: str) -> CDLL | ModuleType:
try:
result = cls._load_library_inner(path, key)
result.key = key # type: ignore[union-attr]
@ -2910,7 +2917,7 @@ def _worker_compile_cpp(
# Customized Python binding for cpp kernels
@clear_on_fresh_cache
class CppPythonBindingsCodeCache(CppCodeCache):
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
cache_clear = staticmethod(cache.clear)
cpp_compile_command_flags = {
# kernels have no dependency on libtorch
@ -3092,7 +3099,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
@clear_on_fresh_cache
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
cache_clear = staticmethod(cache.clear)
cpp_compile_command_flags = {
"include_pytorch": True,
@ -3161,7 +3168,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
@clear_on_fresh_cache
class HalideCodeCache(CppPythonBindingsCodeCache):
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
cache: dict[str, Callable[[], ModuleType | CDLL]] = {}
cache_clear = staticmethod(cache.clear)
_standalone_runtime_path: str | None = None
prefix = textwrap.dedent(

View File

@ -950,6 +950,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
or _all_in_parens(string)
):
# don't put extra parens for strings that are already wrapped in parens
# pyrefly: ignore # bad-return
return string
return f"({string})"
@ -1736,7 +1737,9 @@ class KernelArgs:
)
)
for outer, inner in chain(
self.input_buffers.items(), self.output_buffers.items()
# pyrefly: ignore # bad-argument-type
self.input_buffers.items(),
self.output_buffers.items(),
):
if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
continue
@ -2047,6 +2050,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
) -> None:
super().__init__()
if increase_kernel_count:
# pyrefly: ignore # bad-assignment
metrics.generated_kernel_count += 1
self.args = args or KernelArgs()
self.loads = IndentedBuffer()
@ -2113,6 +2117,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.compute = compute
self.stores = stores
self.cse = cse
# pyrefly: ignore # unbound-name
if disallow_stores:
assert not sb, "unexpected store inside swap_buffers"
@ -2384,6 +2389,7 @@ class KernelTemplate:
class DetailedTemplateSyntaxError(TemplateSyntaxError):
def __init__(self, original_error: TemplateSyntaxError) -> None:
super().__init__(
# pyrefly: ignore # bad-argument-type
original_error.message,
original_error.lineno,
original_error.name,
@ -2395,6 +2401,7 @@ class KernelTemplate:
error_info = f"Error in template at line {self.lineno}\n"
error_info += f"Error message: {self.message}\n"
if hasattr(self.original_error, "source"):
# pyrefly: ignore # missing-attribute
lines = self.original_error.source.split("\n")
error_info += "Context:\n"
start = max(0, self.lineno - 2)
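The annotations added here and in the C++-codegen files below all take the same form: a # pyrefly: ignore comment naming the diagnostic, placed directly above the flagged statement. A made-up illustration of that shape (the error-code name mirrors the ones used in these hunks):

from typing import Optional

# Declared Optional to mimic the kind of attribute the checker flags above.
generated_kernel_count: Optional[int] = 0

def bump_kernel_count() -> None:
    global generated_kernel_count
    # pyrefly: ignore # bad-assignment
    generated_kernel_count += 1  # the suppression sits directly above the statement it silences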

View File

@ -504,6 +504,7 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
return cls(
node1.scheduler,
# pyrefly: ignore # bad-argument-type
(
list(node1.get_outer_nodes())
if type(node1) is OuterLoopFusedSchedulerNode
@ -1716,6 +1717,7 @@ class CppVecOverrides(CppOverrides):
body_vec_var.dtype = dtype
other_vec_var.dtype = dtype
overrides: type[Union[CppOverrides, CppVecOverrides]] = (
# pyrefly: ignore # bad-assignment
V.kernel.overrides
) # type: ignore[has-type]
code.writeline(
@ -1759,6 +1761,7 @@ class CppVecOverrides(CppOverrides):
csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment]
None, index, dtype, V.kernel.compute
)
# pyrefly: ignore # missing-attribute
csevar.update_on_args("index_expr", (expr, dtype), {})
return csevar
@ -2036,6 +2039,7 @@ class CppKernel(Kernel):
# mask's dtype should be bool
mask.dtype = torch.bool
# pyrefly: ignore # bad-assignment
self._load_mask = mask
try:
yield mask
@ -2363,6 +2367,7 @@ class CppKernel(Kernel):
sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
for n in range(len(self.ranges))
]
# pyrefly: ignore # bad-assignment
self.reduction_depth = len(lengths)
return (
self.itervars[: self.reduction_depth],
@ -2610,7 +2615,9 @@ class CppKernel(Kernel):
and end == self.ranges[var_id]
):
end = 1
# pyrefly: ignore # bad-argument-type
conditions.append(f"{var} >= {cexpr_index(start)}")
# pyrefly: ignore # bad-argument-type
conditions.append(f"{var} < {cexpr_index(end)}")
return True
@ -4085,6 +4092,7 @@ class CppKernelProxy(CppKernel):
and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP
):
# No need to promote to float if all users are ops that accepts lowp fp input
# pyrefly: ignore # bad-argument-type
if all(is_lowp_fp_sink(user, dt) for user in _node.users):
continue
ops = _node.args[0]
@ -4095,12 +4103,14 @@ class CppKernelProxy(CppKernel):
_node.replace_all_uses_with(
to_type_node, lambda n: n is not to_type_node
)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
elif (
_node.target == "store"
and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP
):
ops, name, _, value_var, _ = _node.args
# pyrefly: ignore # bad-argument-type
if is_lowp_fp_source_no_promote(value_var, dt):
continue
dtype = V.graph.get_dtype(name)
@ -4109,6 +4119,7 @@ class CppKernelProxy(CppKernel):
"to_dtype", args=(ops, value_var, dtype)
)
_node.replace_input_with(value_var, to_type_node)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
elif _node.target == "reduction":
(
@ -4178,6 +4189,7 @@ class CppKernelProxy(CppKernel):
"to_dtype", args=(ops, value_var, src_dtype)
)
_node.replace_input_with(value_var, to_type_node)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
# to_dtype_bitcast act as a lowp fp source:
@ -4196,6 +4208,7 @@ class CppKernelProxy(CppKernel):
_node.replace_all_uses_with(
to_type_node, lambda n: n is not to_type_node
)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
def eliminate_to_dtype(sub_graph: torch.fx.Graph):
@ -4289,6 +4302,7 @@ class CppKernelProxy(CppKernel):
with kernel_group.new_kernel(cls, *args) as kernel:
# Ugly hack to maintain the metrics kernel count since
# we only count in CppKernelProxy, not those contained in it
# pyrefly: ignore # bad-assignment
metrics.generated_kernel_count -= 1
run(kernel)
@ -4360,6 +4374,7 @@ class CppKernelProxy(CppKernel):
)
if len(tiling_indices) == 1:
# pyrefly: ignore # bad-assignment
metrics.generated_cpp_vec_kernel_count += 1
loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0])
vec_kernel = codegen_kernel(
@ -4386,6 +4401,7 @@ class CppKernelProxy(CppKernel):
and tiling_factors[0] == tiling_factors[1]
)
# pyrefly: ignore # bad-assignment
metrics.generated_cpp_vec_kernel_count += 2
outer_loop = self.loop_nest.tile(
tiling_indices[0], factor=tiling_factors[0]
@ -5134,10 +5150,12 @@ class CppScheduling(BaseScheduling):
contiguous_index_expr = 0
stride = 1
for var, range in reversed(
# pyrefly: ignore # missing-attribute
scheduler_node._body.var_ranges.items()
):
contiguous_index_expr += stride * var
stride *= range
# pyrefly: ignore # missing-attribute
write_index_expr = scheduler_node._body.get_write_expr(
scheduler_buffer.get_name()
)
@ -5206,6 +5224,7 @@ class CppScheduling(BaseScheduling):
)
local_buffers.append(local_buffer_used)
local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index]
# pyrefly: ignore # index-error
local_to_global_buffers[local_buffer_used.name].append(
global_buffer,
)
@ -5450,6 +5469,7 @@ class CppScheduling(BaseScheduling):
wrapper = V.graph.wrapper_code
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
# pyrefly: ignore # bad-argument-type
kernel_name,
)
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
@ -5771,6 +5791,7 @@ class LoopNest:
loop = self.loops[par_depth.start_depth]
loop.parallel = par_depth.parallel_depth
if loop.is_reduction:
# pyrefly: ignore # bad-assignment
metrics.parallel_reduction_count += 1
for i in range(par_depth.start_depth + 1, par_depth.parallel_depth):
self.loops[i].collapsed = True

View File

@ -396,12 +396,15 @@ def transpose_w(W: _T, trans_w: bool) -> _T:
if isinstance(W, ir.IRNode):
if trans_w:
if not isinstance(W, ir.TensorBox):
# pyrefly: ignore # bad-assignment
W = ir.TensorBox(W)
W = L.permute(W, [1, 0])
else:
if trans_w:
assert isinstance(W, torch.Tensor)
# pyrefly: ignore # bad-assignment
W = W.transpose(0, 1)
# pyrefly: ignore # bad-return
return W
@ -412,12 +415,15 @@ def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
if B is not None:
if isinstance(B, ir.IRNode):
if not isinstance(B, ir.TensorBox):
# pyrefly: ignore # bad-assignment
B = ir.TensorBox(B)
assert hasattr(X, "get_size")
# pyrefly: ignore # missing-attribute
B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
else:
assert isinstance(B, torch.Tensor)
assert isinstance(X, torch.Tensor)
# pyrefly: ignore # bad-assignment
B = B.expand(X.shape[0], B.shape[-1])
return B
@ -1043,6 +1049,7 @@ class CppGemmTemplate(CppTemplate):
return cls.prep_weight(
new_inputs,
new_layout,
# pyrefly: ignore # bad-argument-type
micro_gemm,
pre_block_weights,
use_int8_fast_compensation_path,
@ -1066,6 +1073,7 @@ class CppGemmTemplate(CppTemplate):
new_input_nodes, _ = cls.prep_weight(
new_input_nodes,
new_layout,
# pyrefly: ignore # bad-argument-type
micro_gemm,
pre_block_weights,
use_int8_fast_compensation_path,
@ -1470,7 +1478,9 @@ class CppGemmTemplate(CppTemplate):
assert isinstance(template_buffer, ir.IRNode)
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
gemm_output_buffer = ir.Buffer(
name=gemm_output_name, layout=template_buffer.layout
# pyrefly: ignore # missing-attribute
name=gemm_output_name,
layout=template_buffer.layout,
)
current_input_buffer = gemm_output_buffer
for i, creator in enumerate(epilogue_creators):
@ -1481,6 +1491,7 @@ class CppGemmTemplate(CppTemplate):
epilogues.append(
ir.ComputedBuffer(
name=buffer_name,
# pyrefly: ignore # missing-attribute
layout=template_buffer.layout,
data=creator(current_input_buffer),
)
@ -1490,7 +1501,9 @@ class CppGemmTemplate(CppTemplate):
reindexers.append(None)
if i < len(epilogue_creators) - 1:
current_input_buffer = ir.Buffer(
name=buffer_name, layout=template_buffer.layout
# pyrefly: ignore # missing-attribute
name=buffer_name,
layout=template_buffer.layout,
)
assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
@ -1521,6 +1534,7 @@ class CppGemmTemplate(CppTemplate):
self.n,
self.k,
input_dtype=X.get_dtype(),
# pyrefly: ignore # missing-attribute
input2_dtype=W.get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,

View File

@ -183,12 +183,14 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
)
self.act_mapping = act_mapping
self.gemm_grouped_num = gemm_grouped_num
# pyrefly: ignore # bad-override
self.output_node: list[ir.Buffer] = [
ir.Buffer(name="buf_out" + str(idx), layout=layout)
for idx in range(gemm_grouped_num)
]
@classmethod
# pyrefly: ignore # bad-override
def add_choices(
cls,
choices: list[ChoiceCaller],
@ -231,6 +233,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
if isinstance(inputs[idx], torch.Tensor):
W = inputs[idx]
assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
# pyrefly: ignore # unsupported-operation
new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
return new_inputs, layout_or_out
@ -246,8 +249,10 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
new_input = new_inputs[wgt_idx]
new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
for bias_idx in range(bias_start_idx, len(new_inputs)):
# pyrefly: ignore # bad-argument-type
new_bias = expand_bias(new_inputs[bias_idx], X)
assert new_bias is not None
# pyrefly: ignore # unsupported-operation
new_inputs[bias_idx] = new_bias
return new_inputs, layout_or_out
@ -308,6 +313,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
W_tensor = []
for W_node in W_nodes:
assert W_node.get_name() in V.graph.constants
# pyrefly: ignore # bad-argument-type
W_tensor.append(V.graph.constants[W_node.get_name()])
new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
W_tensor # type: ignore[assignment]
@ -324,6 +330,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
template_buffer.inputs[idx] = (
ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
)
# pyrefly: ignore # bad-return
return output
template = DataProcessorTemplateWrapper(
@ -362,6 +369,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
cur_idx = bias_start_idx
for inp_idx in range(self.gemm_grouped_num):
inp = None
# pyrefly: ignore # index-error
if self.has_bias[inp_idx]:
inp = self.input_nodes[cur_idx]
cur_idx += 1
@ -390,6 +398,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
self.n,
self.k,
input_dtype=X_list[0].get_dtype(),
# pyrefly: ignore # missing-attribute
input2_dtype=W_list[0].get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,
@ -427,6 +436,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
for x_idx in range(wgt_start_idx):
kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
for w_idx in range(self.gemm_grouped_num):
# pyrefly: ignore # unsupported-operation
kernel_args["W" + str(w_idx)] = W_list[w_idx]
for inp_idx in range(self.gemm_grouped_num):
kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]

View File

@ -85,6 +85,7 @@ class CppTemplate(KernelTemplate):
bmreq = CppBenchmarkRequest(
kernel_name=kernel_name,
input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
# pyrefly: ignore # bad-argument-type
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
extra_args=extra_args,
source_code=code,
@ -112,6 +113,7 @@ class CppTemplate(KernelTemplate):
kernel_hash_name,
self.name,
self.input_nodes,
# pyrefly: ignore # index-error
self.output_node[0].get_layout()
if isinstance(self.output_node, Iterable)
else self.output_node.get_layout(),

Some files were not shown because too many files have changed in this diff.