Commit Graph

459 Commits

Author SHA1 Message Date
481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00
17358ce778 Revert "Support torch.compile rng selective activation checkpointing with cudagraph (#146878)"
This reverts commit ad0c879e2203145f6d56df0b95af36822220ab8f.

Reverted https://github.com/pytorch/pytorch/pull/146878 on behalf of https://github.com/wdvr due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/146878#issuecomment-2686767956))
2025-02-27 03:36:16 +00:00
ad0c879e22 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-27 02:08:29 +00:00
452315c84f Fix RuntimeError: value cannot be converted to type int64_t without overflow (#147492)
The exact call is coming from here:

78a94c9114/torch/_inductor/memory.py (L161)

I have no idea why this error is being thrown and what mode/modes might be failing for this

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147492
Approved by: https://github.com/eellison
2025-02-20 08:00:26 +00:00
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
44ee9ca593 [inductor] Add type annotations to _inductor/utils.py (#144108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144108
Approved by: https://github.com/eellison
2025-02-15 23:13:41 +00:00
579b9f2ed9 [inductor] Better exception error messages for cache_on_self (#146652)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146652
Approved by: https://github.com/yanboliang
2025-02-07 21:22:21 +00:00
e9f6e273e7 [inductor] Add typing to common.CSE (#145993)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145993
Approved by: https://github.com/yanboliang
ghstack dependencies: #145916
2025-02-04 16:05:39 +00:00
d3c7e4bb9c Revert "[inductor] Add typing to common.CSE (#145993)"
This reverts commit 8c657ae4be55c6133307ad278c1740af5db133a7.

Reverted https://github.com/pytorch/pytorch/pull/145993 on behalf of https://github.com/atalman due to Sorry need to revert https://github.com/pytorch/pytorch/pull/145916 ([comment](https://github.com/pytorch/pytorch/pull/145993#issuecomment-2632712384))
2025-02-04 03:04:01 +00:00
8c657ae4be [inductor] Add typing to common.CSE (#145993)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145993
Approved by: https://github.com/yanboliang
ghstack dependencies: #145913, #145914, #145915, #145916
2025-02-01 16:34:18 +00:00
57d8278ab9 pickler for GraphModule (#141659)
Pickling GraphModule needs some special handling for wrapping things that normally can't be pickled - but async compile needs to pass them across a wire so we need to be able to serialize it - add some helpers to enable that.

Differential Revision: [D68921318](https://our.internmc.facebook.com/intern/diff/D68921318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659
Approved by: https://github.com/jamesjwu
2025-01-31 05:34:28 +00:00
2de53b3b65 Revert "pickler for GraphModule (#141659)"
This reverts commit c6ad08357bf8e766b5220bfb5cbbfdb2a4ec0ca5.

Reverted https://github.com/pytorch/pytorch/pull/141659 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, please take a look at D68694181 for more details. ([comment](https://github.com/pytorch/pytorch/pull/141659#issuecomment-2617045120))
2025-01-27 22:39:30 +00:00
c6ad08357b pickler for GraphModule (#141659)
Pickling GraphModule needs some special handling for wrapping things that normally can't be pickled - but async compile needs to pass them across a wire so we need to be able to serialize it - add some helpers to enable that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659
Approved by: https://github.com/jamesjwu
2025-01-26 19:29:13 +00:00
a40ead1fd6 Don't fail if fresh_inductor_cache fails to clean up its tmp dir. (#145513)
Summary: I see we have a test failure due to an error removing the tmp dir: https://github.com/pytorch/pytorch/issues/141761. Seems like we should not raise an exception for this case in general. Also, let's clean up the exception handling related to windows. The comment makes it sound like we want to specifically ignore failures cleaning up, but the current impl is swallowing all exceptions.

Fixes #141761

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145513
Approved by: https://github.com/eellison
2025-01-24 03:17:03 +00:00
b963ab5325 [inductor][1/N] triton support post-#5512, main components (#145051)
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This is an initial PR to add support for Triton versions after commit 5512 landed.

The main changes in 5220 and 5512 that need to be supported:
* AttrsDescriptor() gets replaced with a raw dict. The raw dict has the format `{(TUPLES): [["tt.divisibility", 16]]}`, where `(TUPLES)` is a tuple of indices, e.g. `((0,), (1,), (3,))` to indicate that args 0, 1, and 3 are divisible by 16. These indices are, themselves, represented as tuples to support nested inputs (e.g. an argument that's a tuple), but support for tuples is not implemented right now.
* "signature" changes: the signature now contains _all_ args, including constexpr and constant args.
* ASTSource now takes "constexprs" instead of "constants" - for example, equal-to-1 args are constants but not constexprs so we don't need to pass these args as "constants".

What this PR supports:
* Triton versions before Dec 9, 2024, and (partial support for) Triton versions after Jan 1, 2025
* (triton jan 1+) typical inductor-generated triton: updated AttrsDescriptor, signatures, constexpr/constant handling.

What this PR doesn't support (TODO in follow-up PRs):
* Triton versions between Dec 9, 2024 and before Jan 1, 2025
* (triton jan 1+) user-defined triton kernel support (this is implemented already in @anmyachev's patch)
* (triton jan 1+) triton_helper support (failing in triton codegen - needs investigation)
* (triton jan 1+) AOTI / cpp wrapper

thanks to @anmyachev for patches in https://github.com/intel/intel-xpu-backend-for-triton/blob/main/scripts/pytorch.patch, which contains most of these changes already

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145051
Approved by: https://github.com/jansel
2025-01-24 00:34:01 +00:00
46851022ff [Inductor][CPU] Add auto-tuning support for da8w8 sym act sym wgt GEMM (#143187)
## Summary

Templated `int8xint8->int32` GEMM that uses AMX ISA (present on Intel Xeon Gen 4 & above). Any epilogues such as weight scale, activation scale, and bias are applied per output block in a fused manner .
Performs well for large values of `M` dimension (assuming canonical dimensions [`M, K`] and [`K, N`] for the activation & weight matrices'/tensors' sizes) when the activation is quantized per-token.
Also supports SmoothQuant GEMM pattern when activation is quantized per-tensor (scalar scale) or per-token (vector scale is applied as an epilogue in this case).

Also increased coverage of GEMM template for uint8 activation, int8 weight GEMM UTs for when the activation zero point is a 1D tensor (the existing implementation only accepted 0D tensors). However, some of such UTs would have to be explicitly enabled with `max-autotune` Inductor config.

## Performance data

The templated codegened fused GEMM with M=32, K=4096, N=14336 used in LLaMA3 exhibits more than 2x perf-gain compared to oneDNN qlinear + mul (for activation's scale) with 48 cores of one socket of Xeon SP 4th gen Platinum 8468 when per-token quantization is used.

For M=1, K=4096, N=14336, regardless of whether per-tensor quantization was used for activation or per-token, the perf gain was more than 3x.

Intel OpenMP & libtcmalloc had been preloaded. All cores used by the workload corresponded to distinct physical cores.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143187
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel, https://github.com/jgong5

Co-authored-by: Leslie Fang <leslie.fang@intel.com>
2025-01-22 02:27:53 +00:00
bac62341eb PEP585 update - torch/_inductor (#145198)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145198
Approved by: https://github.com/bobrenjc93
2025-01-21 21:04:33 +00:00
9d98b66e7b [Inductor][CPP] Enable Epilogue Fusion for Grouped GEMM Template (#143897)
**Summary**
In this PR, we enable the epilogues fusion and code generation for Grouped GEMM. Here are the high-level description of how we implement it.

**Fusion**

- The Grouped GEMM Template produces a `Template Buffer` with a `MultiOutputLayout` and a set of `MultiOutput Buffers`, where each buffer corresponds to a specific GEMM.
- During the initial round of fusion, the `Template Buffer` and all associated `MultiOutput Buffers` are fused into a `FusedSchedulerNode` by extending the existing fusion design.
- In subsequent fusion rounds, this `FusedSchedulerNode` can further fuse with its epilogues, following the original fusion design principles.

**Code Gen**
We maintain a list of epilogues and codegen it one by one.

- If any of the GEMM has bias, we create  a extra `bias_add` epilogue and prepend it at first of the epilogue list.
- If any of the GEMM has no epilogue, we create a `to_bf16` copy epilogue and append it at last of the epilogue list.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_epilogue
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143897
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #143796
2025-01-14 06:07:50 +00:00
a3ab27b8e0 Migrate from Tuple -> tuple in torch/_inductor (#144264)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144264
Approved by: https://github.com/eellison
2025-01-07 03:27:27 +00:00
934eaa503f [Inductor XPU] Support max-autotune on XPU and reuse the corresponding Inductor UT. (#143266)
This PR aims to add the functionality support of max-autotune for XPU. The current triton templates and configurations are not well optimized for XPU, so the performance is not ready yet. Also the `mm_plus_mm` template have accuracy issues in some cases. We will address these issues in the next PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143266
Approved by: https://github.com/EikanWang, https://github.com/jansel
2024-12-30 23:51:17 +00:00
2da7fb5320 [inductor] Make generated kernels deterministic (#143951)
`"compile_id"` had slipped into our generated Triton code (in the
metadata), which will defeat caching because the same kernels generated
in a different order would not cache hit with eachother.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143951
Approved by: https://github.com/oulgen
2024-12-30 23:35:11 +00:00
1b0d19a2cb Revert "[inductor] Make generated kernels deterministic (#143951)"
This reverts commit 79b354ee37b7d8a06a48ca8cc4e19a3fd006b433.

Reverted https://github.com/pytorch/pytorch/pull/143951 on behalf of https://github.com/wdvr due to failing tests on trunk ([comment](https://github.com/pytorch/pytorch/pull/143951#issuecomment-2564952267))
2024-12-30 02:06:38 +00:00
79b354ee37 [inductor] Make generated kernels deterministic (#143951)
`"compile_id"` had slipped into our generated Triton code (in the
metadata), which will defeat caching because the same kernels generated
in a different order would not cache hit with eachother.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143951
Approved by: https://github.com/oulgen
2024-12-29 19:53:33 +00:00
844e6108f6 Revert "[Inductor XPU] Support max-autotune on XPU and reuse the corresponding Inductor UT. (#143266)"
This reverts commit ad750ae32079020f51f9b7d01237f3ecfa83b6ff.

Reverted https://github.com/pytorch/pytorch/pull/143266 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing some tests in trunk ([comment](https://github.com/pytorch/pytorch/pull/143266#issuecomment-2561303786))
2024-12-24 17:22:57 +00:00
ad750ae320 [Inductor XPU] Support max-autotune on XPU and reuse the corresponding Inductor UT. (#143266)
This PR aims to add the functionality support of max-autotune for XPU. The current triton templates and configurations are not well optimized for XPU, so the performance is not ready yet. Also the `mm_plus_mm` template have accuracy issues in some cases. We will address these issues in the next PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143266
Approved by: https://github.com/EikanWang, https://github.com/jansel
2024-12-24 05:42:36 +00:00
a316a4581d Add mps to GPU_TYPES (#143634)
Because it is a GPU, but don't require a triton, as it does not need one

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143634
Approved by: https://github.com/jansel
2024-12-22 18:37:35 +00:00
af0e159740 [Inductor XPU] Add XPU check for is_big_gpu(). (#143491)
Fix #143472

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143491
Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/EikanWang
2024-12-21 02:27:04 +00:00
8960cb5809 Add support for bfloat16 atomic adds in fbcode (#143629)
Reland https://github.com/pytorch/pytorch/pull/141857 and fallback on A100 which doesn't have bfloat16 atomic add instrs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143629
Approved by: https://github.com/eellison
2024-12-20 23:05:13 +00:00
b4e0e3bfa3 Backout D66648013 (#143433)
Summary:
backing out https://www.internalfb.com/diff/D66648013 (see comments there for justification)

I will reland and disallow the bfloat16 atomics behavior on A100 because it causes a pretty significant performance regression.

Test Plan: This is a revert

Differential Revision: D67357485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143433
Approved by: https://github.com/davidberard98
2024-12-19 00:53:49 +00:00
9275091d6e [provenance_tracking] Dump inductor_triton_kernel_to_post_grad_nodes.json info in debug_trace (#143055)
Summary:
This diff mainly adds code changes to dump `inductor_triton_kernel_to_post_grad_nodes.json` artifact which contains mapping info from post_grad -> inductor kernel code:
`{"inductor_triton_kernel_name": [post_grad_node_0, post_grad_node_1, ..., ], "..."}.`

Example paste: P1695235000 verified on the test model.  See "Test Plan":

We use this artifact to demonstrate provenance tracking in the frontend 3-tab highlighter tool:
https://github.com/YUNQIUGUO/compiler_explorer (copy/pasted the input files for demo purpose for now and will integrate with Shangdi's tool to 4-tab)

https://pxl.cl/66BzK

Note: Currently only supports mapping for inductor's`TritonKernel` type. TODO for enhancing more support for `ExternKernel` and other inductor generated kernel type, etc.

Test Plan:
test_model_coverage.sh:
```
#!/bin/sh
MODEL_ENTITY_ID=644688112
SNAPSHOT_ID=32
MODULE=merge

# buck2 build --show-output mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true -c fbcode.nvcc_arch=a100,h100 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark

TORCH_COMPILE_DEBUG=1 CUDA_VISIBLE_DEVICES=0 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCH_LOGS="+inductor, schedule, fusion, output_code" TORCH_TRACE="tmp/guorachel_tt" TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 ../buck-out/v2/gen/fbcode/d29ee94b913014f1/caffe2/torch/fb/model_transform/experimental/benchmark/__mts_gpu_benchmark__/mts_gpu_benchmark.par --model-path manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend AOT_INDUCTOR_EP --gpu-trace --aot-inductor-config="{'max_autotune': True}" 2>&1 | tee output.txt
```
 {F1973765026}

```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:provenance_tracing -- --exact 'caffe2/test/inductor:provenance_tracing - test_triton_kernel_post_grad_mapping_aot_inductor (caffe2.test.inductor.test_provenance_tracing.TestProvenanceTracingArtifact)'
```

```
TORCH_LOGS="+inductor, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:provenance_tracing -- -r test_triton_kernel_post_grad_mapping_aot_inductor
```

Differential Revision: D66967510

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143055
Approved by: https://github.com/chenyang78
2024-12-18 06:51:50 +00:00
e885225eda Add persistent+TMA version of Triton mm and addmm (#142101)
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning if (checked by the `use_triton_tma_template` helper):

1. The min. hardware and Triton version requirements are met for the TMA support.

2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).

3. The `config.triton.enable_persistent_tma_matmul` is set to `True`.

Additional notes:

1. As added in this PR, the TMA uses are not compatible with prolog / epilogue fusion. To this end, in the new Triton template we currently support: TMA-based loads of A/B, but no prologue fusion; epilogue fusion, but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.

2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Due to this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous.

3. The transposed layouts of A and / or B are supported by passing the constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on the kernel perf, as decided at the Triton compilation time).

4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in https://github.com/triton-lang/triton/pull/5290) in the new Triton template, which should allow lifting 2 and 3 above.

5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment the config in a follow-up.

6. This PR is rebased onto and unifies with two related PRs landed previously: https://github.com/pytorch/pytorch/pull/142045 (some infra unification with the persistent+TMA template for _scaled_mm) and https://github.com/pytorch/pytorch/pull/134532 (add possibility to disable prolog fusion for selected choices).

7. The current Triton TMA API only supports 1D and 2D descriptors (even after https://github.com/triton-lang/triton/pull/5290, see [here](9829ce87cc/python/triton/language/core.py (L1957))). For now, this blocks adding persistent+TMA template for `torch.bmm`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142101
Approved by: https://github.com/drisspg, https://github.com/eellison
2024-12-16 19:12:12 +00:00
da67a6a7bb [inductor] Replace set by OrderedSet (#138466)
Uses the set_linter from https://github.com/pytorch/pytorch/pull/138454
and considerable manual editing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138466
Approved by: https://github.com/eellison
2024-12-13 16:08:45 +00:00
dc23f1944a Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-12 17:39:14 +00:00
520ba556cd [Inductor] Refactor "r" reduction prefix to {"r0_", "r1_"}. (#142020)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243.

# Feature

This PR changes the `RINDEX` / `"r"` symbol type to `(R0_INDEX, R1_INDEX)` and `("r0_", "r1_")`, respectively. This allows the relevant code to support 2D (often ND) reductions. Unlike the parent PR, this one does not change the tiling algorithm, so `"r1_"` is never used. However, it prepares other parts of the system to handle `"r1_"` once we start using it. This should significantly reduce the chances of hitting merge conflicts, making the parent PR much easier to land.

The only change to the generated triton code is to rename `"rindex"` -> `"r0_index"`, `"RBLOCK"` -> `"R0_BLOCK"`, etc. To maintain compatibilty with existing codegen, this also generates aliases to the old reduction variables like `rindex = r0_index`. If we generated 2D reductions (which this PR will not do), the aliases would be more complicated and would collapse 2D multi-indices to linear indices. See some example kernels in the parent PR.

These aliases can be eliminated by the Triton compiler, and should not impact the final machine code running on the GPU. See the perf testing in the parent PR which confirms the aliases do not impact perf.

# Test plan

The existing CI provides good coverage. This PR modifies the expected code in a few places, renaming reduction variables from `r.*` to `r0_.*`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142020
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@meta.com>
2024-12-12 17:22:20 +00:00
5c97ac9721 Revert "Remove unused Python variables in torch/[_-a]* (#133492)"
This reverts commit fda975a7b3071a20dab8fc2c4e453479e1bb7cf2.

Reverted https://github.com/pytorch/pytorch/pull/133492 on behalf of https://github.com/clee2000 due to Sorry, I need to revert this in order to revert something else.  The only thing you need to do is rebase and remerge ([comment](https://github.com/pytorch/pytorch/pull/133492#issuecomment-2536635516))
2024-12-11 17:29:12 +00:00
ed388394d1 add torchrec collectives to enforce global ordering (#141970)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141970
Approved by: https://github.com/yf225
2024-12-11 02:45:24 +00:00
fda975a7b3 Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-10 21:48:44 +00:00
a3abe1a5ae Add support for bfloat16 atomic adds in fbcode (#141857)
This adds support for bfloat16 atomic add in fbcode (OSS will have to wait until those changes are upstreamed to triton)

Originally I attempted to write inline asm, but the triton API was not flexible enough to support this use case. In the long run the right answer is to implement this properly in OSS triton.

relevant issues:
* https://github.com/pytorch/pytorch/issues/137425 in fbcode only
* https://github.com/pytorch/pytorch/issues/97016

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141857
Approved by: https://github.com/eellison
2024-12-10 11:40:15 +00:00
bcddae14ec Enhance "from_node" node meta to track source recursively (#142066)
Summary:
Change the "from_node" node meta format to be able to track the provenance of nodes recursively.

The new "from_node" format is a a list node NodeSource:

```
class NodeSource:
	self.node_name: str
	self.target: str
	self.graph_id: int
	self.pass_name: str
	self.action: str
	self.from_node: List[NoedSource]
```

This is in preparation for the inductor provenance tracking. For background, the inductor provenance tracking doc: https://docs.google.com/document/d/1dGh9myqNhywmbfP0Quzx_f04bghDFlj8cawj8MopiO8/edit?fbclid=IwZXh0bgNhZW0CMTEAAR0jUQ0Tf4ROLDED8Y_eIzrU0KVZVdRmyIQLp-avt-kGRPI_VgYVNyjH_q0_aem_HCQ_pxHDiwOkO9mQyWB2-g&tab=t.0 (internal only),

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_unflatten_multiple_graphs_state
buck run mode/dev-nosan caffe2/test:fx -- -r node_source
```

Differential Revision: D66737916

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142066
Approved by: https://github.com/avikchaudhuri
2024-12-09 23:39:15 +00:00
5c76a2834d Revert "add torchrec collectives to enforce global ordering (#141970)"
This reverts commit ceb94d6a7d38930d662e7eb71b9c7620de8c2997.

Reverted https://github.com/pytorch/pytorch/pull/141970 on behalf of https://github.com/malfet due to Apologies for reverting this change, but it broke MacOS testing, but CI was broken at the time ([comment](https://github.com/pytorch/pytorch/pull/141970#issuecomment-2529367680))
2024-12-09 20:25:04 +00:00
e343f46464 [inductor] Refactor is_big_gpu (#142220)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142220
Approved by: https://github.com/yanboliang
ghstack dependencies: #142219, #142033, #142222
2024-12-08 18:51:36 +00:00
2682e5e0d4 [BE]: Add TypeGuard to is_symbolic (#142304)
Improves type inference for is_symbolic. If it's True, it must be either a SymInt or Torch Tensor currently.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142304
Approved by: https://github.com/jansel
2024-12-08 02:18:17 +00:00
0367a31401 [inductor] Minor typing changes (#142219)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142219
Approved by: https://github.com/Skylion007, https://github.com/yanboliang
2024-12-07 17:48:37 +00:00
ceb94d6a7d add torchrec collectives to enforce global ordering (#141970)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141970
Approved by: https://github.com/yf225
2024-12-06 22:38:54 +00:00
20f24e3fbd [inductor][cpp] Add BMM kernel template for autotuning (#129772)
This PR adds the Cpp template for BMM, for FP32, FP16, and BF16. See #125683 for more background.

1.  Adds `CppBmmTemplate` class which inherits from `CppPackedGemmTemplate`. Given a number of worker threads `num_threads` and batch size `B`, execute the Gemm kernel. For the first `B - (B % num_threads)` batch inputs, run one sub-gemm problem per thread. Then for the remaining `B % num_threads` sub-gemms, we execute each subproblem using the parallelized Gemm kernel.
To manage this code, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered two different times, one with a single thread and one which includes the parallel OMP pragma.
2. Adapts `CppPackedGemmTemplate` to allow for child class. The `GEMM_TEMPLATE` is separated into different strings to allow for rendering by the child class. Slicing/indexing are adapted to allow for 3D BMM inputs. Additional methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.

BMM within `dlrm` benchmark has a single input buffer which is used for but X and W inputs. This is currently not supported in this PR.

### Performance
On Granite/Sapphire Rapids, cpp_bmm template code uses AMX which requires an expensive transpose operation so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, speedup on SPR is identical with and without BMM code. Pass rate matches the rates for main exactly.

#### Test Summary on Granite Rapids
Test   Scenario | Comp Item | Date | Compiler | torchbench | huggingface | timm_models
-- | -- | -- | -- | -- | -- | --
Single Socket Multi-Threads | Pass   Rate | gemm autotune| inductor | 91%,   73/80 | 100%,   46/46 | 100%,   61/61
   |     |   |  bmm + gemm autotune | inductor | 91%,   73/80 | 100%,   46/46 | 100%,   61/61
  |  |  Geomean Speedup | gemm autotune| inductor | 2.15x | 1.91x | 2.52x
   |     |   |  bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x
Single Core Single-Thread | Pass   Rate | gemm autotune | inductor | 91%,   73/80 | 100%,   46/46 | 100%,   61/61
   |    |   |  bmm + gemm autotune| inductor | 91%,   73/80 | 100%,   46/46 | 100%,   61/61
 |  | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x
   |    |   |  inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x

This is not the case on an older Skylake Xeon machine.
For the BMM ops contained in torchbench models, bmm performance improves by 1.10-2.64x.

#### BF16 28-core Skylake Xeon
| Model | Inductor | GemmAutotune | Gemm+BMM Autotune |
|--------|--------|--------|--------|
| BERT_pytorch | 1.233x | 2.597x | 2.608x |
| hf_DistilBert | 1.128x | 2.242x | 2.368x |
| hf_Reformer | 1.124x | 1.419x | 1.590x |
| hf_T5_base | 1.012x | 1.257x | 1.382x |
| hf_T5_large | 1.085x | 2.228x | 2.345x |

## Example BMM Code
```
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_32_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
        _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
        _tile_zero(2);
        _tile_zero(3);
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
        _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 4, 6);
        _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(1, 4, 7);
        _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(2, 5, 6);
        _tile_dpbf16ps(3, 5, 7);
    };

    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }

    auto store_c = [&]() {
    // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
        _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float));
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}

template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_16_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
        _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(3, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 2, 3);
        _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(1, 2, 4);
    };

    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }

    auto store_c = [&]() {
    // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}

template <bool accum>
inline void cpp_bmm_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
    AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t n = 0; n < N; n += 32) {
        for (int64_t m = 0; m < M; m += 32) {
            int64_t block_m = std::min<int64_t>(M - m, 32);
            int64_t m_tail = m;
            if (block_m >= 32) {
                cpp_bmm_micro_gemm_amx_kernel_32_2<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= 32;
                m_tail += 32;
            }
            else
            if (block_m >= 16) {
                cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= 16;
                m_tail += 16;
            }
            if (block_m > 0) {
                cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
                    amx_state,
                    A + m_tail * lda,
                    B + n,
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index)
{

    constexpr int64_t num_threads = 48;
    constexpr int64_t N = 64;
    constexpr int64_t K = 96;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
    constexpr int64_t M = static_cast<int64_t>(384L);
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = 1;
    constexpr int64_t Nt_blocks = 1;
    constexpr int64_t Kt_blocks = 3;
    constexpr int64_t Mc_blocks = 1;
    constexpr int64_t Nc_blocks = 1;
    constexpr int64_t Kc_blocks = 3;
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;

    // make sure all partitions are assigned
    AOTI_TORCH_CHECK(
        Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
    #pragma omp parallel num_threads(48)
    {
        const int tid = omp_get_thread_num();
        const int64_t k_group_id = tid / num_Kt_blocks;
        const int64_t k_slice_id = tid % num_Kt_blocks;
        const int64_t n_group_id = k_group_id / num_Nt_blocks;
        const int64_t n_slice_id = k_group_id % num_Nt_blocks;
        const int64_t k_block_start = k_slice_id * Kt_blocks;
        const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
        const int64_t n_block_start = n_slice_id * Nt_blocks;
        const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
        const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
        const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
        const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
        AMXState amx_state;
        auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        } else {
                            cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        }
                    }
                }
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
                            }
                            for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                            }
                        }
                    }

                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index)
{

    constexpr int64_t num_threads = 1;
    constexpr int64_t N = 64;
    constexpr int64_t K = 96;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
    constexpr int64_t M = static_cast<int64_t>(384L);
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = 12;
    constexpr int64_t Nt_blocks = 2;
    constexpr int64_t Kt_blocks = 3;
    constexpr int64_t Mc_blocks = 12;
    constexpr int64_t Nc_blocks = 1;
    constexpr int64_t Kc_blocks = 3;
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;

    // make sure all partitions are assigned
    AOTI_TORCH_CHECK(
        Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
    {
        constexpr int tid = 0;
        constexpr int64_t k_group_id = 0;
        constexpr int64_t k_slice_id = 0;
        constexpr int64_t n_group_id = 0;
        constexpr int64_t n_slice_id = 0;
        constexpr int64_t m_block_start = 0;
        constexpr int64_t n_block_start = 0;
        constexpr int64_t n_block_end = Nr_blocks;
        constexpr int64_t k_block_start = 0;
        constexpr int64_t k_block_end = Kr_blocks;
        constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
        constexpr int64_t m_block_end = Mr_blocks;
        AMXState amx_state;
        auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        } else {
                            cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        }
                    }
                }
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
                            }
                            for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                            }
                        }
                    }

                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
extern "C"
void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y)
{
    const int64_t B = static_cast<int64_t>(5L);
    constexpr int64_t num_threads = 48;
    int64_t B_single_thread_block = (B / num_threads) * num_threads;

    #pragma omp parallel for num_threads(48)
    for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
        single_thread_mm(X, W, Y, b_start);
    }
    for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
        threaded_mm(X, W, Y, b_start);
    }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129772
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-12-06 04:54:00 +00:00
471017cbc9 avoid specializing strides with DDPOptimizer + inductor (#140751)
Fixes https://github.com/pytorch/pytorch/issues/140229

Fixes https://github.com/pytorch/pytorch/issues/139474

The issue was that:

(1) DDPOptimizer has some logic to partition the dynamo graph into buckets, and run AOTAutograd/inductor on each bucket

(2) doing so requires knowing the **exact** strides of the outputs of each subgraph, so we can have example inputs (with correct strides) to each of the later subgraphs to compile with

(3) there is some existing logic to do this today: we have a `fakify_first_call` flag in AOTAutograd that lets you run it with fake tensor inputs (to handle the calling convention changes that AOTAutograd performs at runtime). During this process, we query inductor for the output strides that it compiled with

(4) these outputs strides are stored in the FX graph cache as raw strings of sympy expressions. We have a function, `evaluate_symexpr`, which given the sympy string, and the ShapeEnv's `var_to_val` mapping, will evaluate the sympy string to generate concrete strides

(5) evaluating this expression will specialize on the exact values of any variables in our shape env, however. In DDPOptimizer, we want to know what inductor's stride outputs are symbolically. This requires converting the (string) sympy expression into actual `SymInts` that we can return.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140751
Approved by: https://github.com/eellison
2024-12-05 03:41:12 +00:00
fd35be2fd3 TritonTemplate dtype fixes (#141991)
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype
- Update dtype propagation to use `type_to_dtype` instead of instantiating tensor
- Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991
Approved by: https://github.com/drisspg
ghstack dependencies: #139945, #140057, #141495, #141882
2024-12-04 17:24:23 +00:00
5c2584a14c [ROCm] Enable inductor GEMM lowering for gfx11 (#141687)
This check doesn't make sense for some of the AMD gpus since they have the right amount of CUs but multi_processor_count returns WGPs on RDNA while still performing adequately. A lot of tests fail on modern archs due to this check defaulting them to not using the GEMMs backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141687
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2024-12-02 22:13:34 +00:00
f83361b274 inductor dtype propagation fixes (#141495)
- Add in upcast_compute_type on creation of new tensors (loads, constants)
- Fixes index_expr - right now we are sort of inconsistent in dtype and dont always respect the dtype specified. would be nice to fix but not doing in this pr.
- bug fix in view dtype where we were always upcasting back to fp32 when input was in bf16/fp16. we should only be doing that if the output is also in bf16/fp16.
- for masked, avoid calling dtype propagation and just use output dtype.

Turns on the runtime dtype verification for opinfo tests. The separate test file is still useful because we can use it for testing turning off codegen_upcast_to_fp32.

Follow ups:

- We could consider requiring less explicit upcast_compute_types calls and do it automatically. That would potentially make things easier but be less flexible in the future. Maybe I should have done it this pr.
- Be more consistent on our index expr dtype printing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141495
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
ghstack dependencies: #139945, #140057
2024-11-28 11:39:38 +00:00
dbbebee9d7 Code motion CompiledFxGraph to a dedicated file (#141654)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141654
Approved by: https://github.com/aorenste, https://github.com/jansel
ghstack dependencies: #141491, #141492, #141574
2024-11-27 20:42:21 +00:00