Summary:
Moves the function used to load CuTeDSL Jinja templates up one level, out of the flex attention folder, so it can be reused by more general Inductor templates in the future.
Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`
Reviewed By: drisspg
Differential Revision: D84527470
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165576
Approved by: https://github.com/jananisriram
Summary: Moves the function used to load CuTeDSL Jinja templates up one level, out of the flex attention folder, so it can be reused by more general Inductor templates in the future.
Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:flex_flash -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`
Differential Revision: D84527470
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165347
Approved by: https://github.com/drisspg
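For context, a loader along these lines can live at the shared level and be reused by any Inductor template that renders CuTeDSL code. This is a hypothetical sketch; the actual helper name and directory layout in `torch/_inductor` may differ.
```python
from pathlib import Path

from jinja2 import Environment, FileSystemLoader, StrictUndefined

# Hypothetical sketch of a shared CuTeDSL template loader; names and the
# template directory are illustrative, not the function moved by this PR.
_TEMPLATE_DIR = Path(__file__).parent / "templates"
_env = Environment(loader=FileSystemLoader(_TEMPLATE_DIR), undefined=StrictUndefined)

def load_cutedsl_template(name: str):
    """Return the named CuTeDSL Jinja template, shared across Inductor templates."""
    return _env.get_template(f"{name}.py.jinja")

# usage: load_cutedsl_template("grouped_mm").render(kernel_name="my_kernel")
```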
This is a follow-up of #165037. It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of eventually enabling the related linter checks.
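For illustration, the pattern being standardized (a generic Python example, not a specific hunk from this PR):
```python
class Base: ...
class Derived(Base): ...

x = Derived()

# `==` on type objects routes through __eq__, which a metaclass can override;
# `is` compares identity of the exact type object, which is what these checks intend.
print(type(x) == Derived)   # True, but flagged by linters (e.g. flake8 E721)
print(type(x) is Derived)   # True - the preferred exact-type check
print(type(x) is Base)      # False - subclasses are not accepted
print(isinstance(x, Base))  # True - use isinstance() when subclasses are fine
```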
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
As title
On Windows, we cannot modify the .dll to append weights at the end; the Windows DLL loader will complain that it is not a valid .dll file. So we store the weight blob as a separate file.
1. We add the following APIs, which allow querying the size of the weight blob and passing in a pointer to it (a usage sketch follows this list).
```cpp
AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize(
AOTInductorModelContainerHandle container_handle,
uint64_t* ret_size);
// Load weights from a single blob in weight_blob_ptr
AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob(
AOTInductorModelContainerHandle container_handle,
const uint8_t* weight_blob_ptr);
```
2. We also add a method in ModelContainerRunner to load the weights: if the runner sees a `.blob` file in the package, it will mmap the `.blob` file and use its contents to load the constants.
3. We also add the `USE_MMAP_EXTERNAL` macro. When this macro is defined, the model expects to load its weights from the external mmap'd weight blob.
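A rough usage sketch of the two APIs above (error handling elided; it reads the blob with standard file I/O rather than the runner's mmap path, and assumes the container handle was already created via the usual AOTI container API and that the AOTI runtime interface header declaring these functions is included):
```cpp
#include <cstdint>
#include <fstream>
#include <vector>
// plus the AOTI runtime interface header declaring the two APIs above

void load_weights_from_blob(AOTInductorModelContainerHandle container,
                            const char* blob_path) {
  // Ask the container how large the single weight blob must be.
  uint64_t blob_size = 0;
  AOTInductorModelContainerGetConstantsBlobSize(container, &blob_size);

  // Read the packaged model.wrapper_weights.blob file into memory.
  std::vector<uint8_t> blob(blob_size);
  std::ifstream in(blob_path, std::ios::binary);
  in.read(reinterpret_cast<char*>(blob.data()),
          static_cast<std::streamsize>(blob_size));

  // Hand the raw weight bytes to the container, which initializes the
  // constants from this single blob.
  AOTInductorModelUpdateConstantsFromBlob(container, blob.data());
}
```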
Test Plan:
```
buck run @mode/dev-nosan caffe2/test/inductor:test_aot_inductor -- -r test_large_mmaped_weights_on_disk
```
Also tested Windows cross-compilation with 6542566585/demo/main_voxtral.cpp
```
Loaded model.dll
audio_encoder loaded
C:\Users\shangdiy\source\repos\torchnative\demo\token_embedding\data\aotinductor\model\model.wrapper.so
Loaded model.dll
token_embedding loaded
C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper.so
Loaded model.dll
Loading weights from C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper_weights.blob
text_decoder loaded
Load latency (ms):
audio_encoder: 1011.234
archive extraction: 0.000
.so loading: 1011.197
token_embedding: 525.773
archive extraction: 0.000
.so loading: 525.704
text_decoder: 3324.130
archive extraction: 0.000
.so loading: 3323.979
Run latency (ms):
audio_encoder: 285.958
audio_encoder output: dtype=bfloat16, shape=[1, 1125, 3072], numel=3456000
token_embedding: 6.676
token_embedding output: dtype=bfloat16, shape=[1, 1138, 3072], numel=3495936
text_decoder: 576.519
text_decoder output: dtype=bfloat16, shape=[1, 1138, 131072], numel=149159936
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164526
Approved by: https://github.com/desertfire
Verify the deterministic mode with torch.compile benchmark scripts.
Here is what my testing script does (pasted at the end):
- run a model in default mode and save its result
- run the model again in default mode, but distort the benchmarking results; compare with the saved result
- do the above again in deterministic mode
I tried to test a few models:
- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchmark result in the default mode. The non-determinism is gone in deterministic mode.
- DistillGPT2: I cannot repro the numeric change by distorting the benchmarking result in the default mode. That does not surprise me much; a reduction-order change does not always cause a numeric change.
```
model=GoogleFnet
export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1
# Non deterministic mode
# --float32 rather than --amp to make it easier to repro non-deterministic
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl
echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl
echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl
echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532
To run this, you need to install `mingw64-gcc-c++` and download the Windows CUDA library toolkit.
See design doc and demo instructions in https://docs.google.com/document/d/1iDaChqA5nNKkBFTzsdkmoomvQlXHbnlb1Z4yEp7xaJA/edit?tab=t.0
If cross_platform_target is windows, we do the following:
- Do not link to `sleef`. This can be improved in the future if we need it; currently I avoid it because it requires extra setup on the Linux side.
- Use `mingw64-gcc-c++` to compile.
- Use `WINDOWS_CUDA_HOME` instead of `CUDA_HOME` when linking to CUDA.
```
python test/inductor/test_aot_inductor_windows.py -k so
```
Other changes:
- decouples the compile_standalone config from the dynamic-link flag
- creates a new aot_inductor_mode config module, which is used to control configs in aot_inductor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163188
Approved by: https://github.com/desertfire
Summary:
Triton templates tend to perform very poorly on large K, hence the introduction of decompose_k. As a result, when decompose_k is selected we disable exploring the Triton templates. We may want to consider an override in the future.
Note: Based on the timing results it may be desirable to better refine/prune the decompose k decisions.
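For context, decompose_k splits the reduction dimension into chunks, batches the partial matmuls, and then reduces. A rough PyTorch-level illustration of the idea (not Inductor's actual lowering):
```python
import torch

# Illustrative only: the idea behind decompose_k, not Inductor's generated code.
M, K, N, k_parts = 256, 32768, 256, 64
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn(K, N, device="cuda", dtype=torch.float16)

A_ = A.reshape(M, k_parts, K // k_parts).transpose(0, 1)  # (k_parts, M, K/k_parts)
B_ = B.reshape(k_parts, K // k_parts, N)                  # (k_parts, K/k_parts, N)
C = torch.bmm(A_, B_).sum(dim=0)                          # reduce the partial products

# Small numerical difference from the changed reduction order is expected.
print((C - A @ B).abs().max())
```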
Testing:
Tested by looking at the autotune/compilation time using a single shape in TritonBench.
`TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 python run --op gemm --rep 1000 --sleep 1.0 --m 512 --n 512 --k 300000 --only pt2_matmul_maxautotune`
Before this change:
`SingleProcess AUTOTUNE benchmarking takes 13.5368 seconds and 0.1595 seconds precompiling for 38 choices`
With this change:
`SingleProcess AUTOTUNE benchmarking takes 9.9626 seconds and 0.0020 seconds precompiling for 11 choices`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163781
Approved by: https://github.com/eellison, https://github.com/PaulZhang12
# Feature
Support `torch.cond` in the FX converter. The generated FX IR is conceptually identical to what would come from `torch.export` (a minimal user-level example is sketched below):
- Submodules are stored as attributes and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes in the subgraphs, a predicate, and the submodule inputs.
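A minimal sketch of the user-level pattern this covers (the FX-converter backend wiring itself is omitted; the default compile path is shown just to exercise `torch.cond`):
```python
import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(pred, x):
    # Lowered to torch.ops.higher_order.cond, with the two branches stored as
    # subgraph attributes on the emitted FX module.
    return torch.cond(pred, true_fn, false_fn, (x,))

compiled = torch.compile(f)
out = compiled(torch.tensor(True), torch.randn(4))
```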
# Implementation overview
The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by torch.export.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond`.
# Implementation details
This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.
Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.
Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning.
In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)
Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.
This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, and not subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.
# Test plan
This PR adds a couple of tests using torch.cond. Here's an example graph generated by one of them:
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
%buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
return buf1
```
It also removes an existing negative test which checked that a certain error was raised when subgraphs were encountered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163234
Approved by: https://github.com/angelayi, https://github.com/jansel
Summary:
I am really skeptical about Inductor sizevars creating an empty shape env when not provided with one.
I think we should fail there if the graph has dynamic shapes and no shape env is provided.
However, I wonder whether there are actually use cases that depend on the shape env not being there?
Reasoning APIs depend on facts in the shape env and assume some things exist for specific symbols.
Test Plan:
Fixes the bug reported in https://www.internalfb.com/diff/D82337184 (creating a simple e2e unit test is not trivial).
Differential Revision: D82412384
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162927
Approved by: https://github.com/ezyang, https://github.com/eellison, https://github.com/jansel
Entering a device context takes 30 us and exiting a device context takes 11 us. If all graph partitions and cudagraph-unsafe ops happen on the same device, we can share the device context.
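Conceptually (an illustrative sketch, not the actual generated wrapper code), the change hoists one device guard around all same-device partitions instead of re-entering it per partition:
```python
import torch

def partition_0(x):
    return x.relu()

def partition_1(x):
    return x.sin()

def run_before(x):
    # one enter/exit per partition (~30 us + ~11 us each)
    with torch.cuda.device(x.device.index):
        y = partition_0(x)
    with torch.cuda.device(x.device.index):
        return partition_1(y)

def run_after(x):
    # single shared device context when every partition targets the same device
    with torch.cuda.device(x.device.index):
        return partition_1(partition_0(x))
```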
## Trace
Use vLLM as an example. The first trace shows dynamo graph partition.
<img width="1338" height="453" alt="image" src="https://github.com/user-attachments/assets/b81815fd-cdcb-4024-846a-5b64164f8bac" />
The second trace shows inductor graph partition prior to this PR.
<img width="1331" height="270" alt="image" src="https://github.com/user-attachments/assets/8d98b127-2053-4eae-9a31-5491661f14d8" />
Compared with the FX graph partition, the Inductor graph partition shows extra overhead from entering/exiting device contexts (13+6 us -> 30+11 us) but smaller runtime overhead (13 us -> 7 us). This motivates this PR to share the default device context.
The third trace shows the Inductor graph partition after this PR. The extra overhead from entering/exiting device contexts is gone, while the smaller runtime overhead is preserved.
<img width="1336" height="276" alt="image" src="https://github.com/user-attachments/assets/77be2237-34dd-4bac-ad9c-d9af3be36417" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162873
Approved by: https://github.com/shunting314
Summary:
This adds the Triton tutorial persistent matmul with device-side TMA and adds it as a template option for Blackwell. It uses newer Triton features such as automatic warp specialization and loop flattening, which, while still containing flaws, can improve performance on Blackwell. This does not include the epilogue subtiling section, as that will be a follow-up PR.
This PR doesn't include any tuning. I am doing a larger benchmarking run to determine the best initial configs for tuning and will open a followup PR with better defaults soon.
Test Plan:
Tested on a Blackwell machine with test_max_autotune.py and confirmed the new tests pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162916
Approved by: https://github.com/NikhilAPatel
Summary:
Adds support for TMA store in all TMA matmul templates (notably persistent_tma, including addmm and scaled_mm). This works by requiring that a template be registered with `tma_store=True`; when that is set, we construct the indices/range_trees to hook into the existing code base's TMA store support.
This also includes a couple of notable changes:
- Adds support in the TMA templates for checking the output layout.
- Adds support for "hoisting" the tensor descriptor to the top of the kernel. This is currently only used by template code, but in principle it can be generalized to other implementations.
- Supports considering multiple indices as the "contiguous" index. This is handled by transposing the input data when the alignment is no longer consistent. In general, since the TMA support is derived from the index, it doesn't seem reasonable that the 1D index math forces a certain alignment depending on index ordering, so long as the layout matches.
Test Plan:
Tested with test_max_autotune.py unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160480
Approved by: https://github.com/NikhilAPatel
Previously, DeviceInfo provided theoretical hardware information based on a hardcoded list manually created from various datasheets.
This update:
- Attempts to gather the information from a hardware library like `pynvml`, improving accuracy and expanding support to devices that don't have entries in the datasheet list (a sketch of the kind of queries involved is below).
- Adjusts the flops and bandwidth calculations based on these hardware values. For example, if the memory or SMs are underclocked, it adjusts the theoretical max flops/bandwidth accordingly.
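A sketch of the kind of `pynvml` queries involved (illustrative only; the actual DeviceInfo code may use different calls and derivations, and the bandwidth formula below is a rough DDR-style estimate):
```python
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

# Current vs. maximum SM and memory clocks (MHz); underclocking shows up as
# current < max and scales the theoretical flops/bandwidth down accordingly.
sm_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)
sm_clock_max = pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM)
mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM)
bus_width_bits = pynvml.nvmlDeviceGetMemoryBusWidth(handle)

# Rough peak-bandwidth estimate in GB/s (2 transfers per memory clock).
peak_bw_gbps = 2 * mem_clock * 1e6 * bus_width_bits / 8 / 1e9
print(sm_clock, sm_clock_max, peak_bw_gbps)

pynvml.nvmlShutdown()
```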
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162245
Approved by: https://github.com/v0i0, https://github.com/shunting314
Reopened from #158747, which got reverted because the wheel cannot be built without setuptools-scm in the PyTorch index URL.
We reconsider the original PR's idea of introducing CK as a PyTorch dependency on ROCm Linux and instead install the CK Python package in CI only, since (1) rocm-composable-kernel depends on setuptools-scm, which depends on tomli, and the existing index URLs would need to be modified to host the new packages, and (2) there is also a packaging [bug](https://github.com/pypa/setuptools/issues/3269#issuecomment-1254507377) in Ubuntu 22.04 which prevents correct dynamic version calculation with the default system pip.
Extras:
-> This PR reconsiders how the TORCHINDUCTOR_CK_DIR env variable is used; previously, this variable pointed to the rocm-composable-kernel package installation path on the filesystem; now, the path is inferred by trying to import ck4inductor (see the sketch after this list)
-> The tests are updated to reflect this change
-> Since in CI clang points to a bash script which invokes sccache, we cannot patch PATH to not contain sccache, so this logic is removed from the testing code
-> The scaled_mm test crashes during benchmarking when benchmarking happens in the main process, and times out when it happens in a subprocess, on gfx942, so it is disabled
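The path-inference idea in the first item, roughly (a sketch; the real Inductor code may differ):
```python
import os

def infer_ck_dir():
    # Prefer the installed ck4inductor package; fall back to the env variable
    # for users who still point TORCHINDUCTOR_CK_DIR at a checkout.
    try:
        import ck4inductor
        return os.path.dirname(ck4inductor.__file__)
    except ImportError:
        return os.environ.get("TORCHINDUCTOR_CK_DIR")
```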
TBD: roll back rocm-mi300 workflow before merging
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162288
Approved by: https://github.com/jeffdaily
# why
- heuristics providers now decide whether to add choices (and which ones) in the max-autotune case
- enables an eventual override point to gracefully fall back to the standard behavior
# what
- max-autotune is determined inside V.choices.get_mm_configs; because it's mm-only right now, we can just use `config.max_autotune or config.max_autotune_gemm`. A TODO indicates that this can change in the future when this expands to more templates
# testing
```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```
Differential Revision: [D81520573](https://our.internmc.facebook.com/intern/diff/D81520573)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161344
Approved by: https://github.com/jansel
ghstack dependencies: #162075, #161340, #161341, #161342, #161343
## Summary
Adds a subgraph decomposition for addmm and mm that performs well when `K` is large compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), since `split-k` does not currently support AMD.
## Background
On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
))
```
is a lot slower than:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.
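The decomposition being autotuned is conceptually just this (a sketch, not the actual subgraph template):
```python
import torch

def contiguous_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Materialize a contiguous copy of B so the matmul reads it with unit-stride
    # rows; the autotuner benchmarks this against the regular extern/Triton
    # choices and keeps whichever is faster.
    return torch.mm(a, b.contiguous())
```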
## Data
I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
addmm_16448x512x2048: +0.14%
addmm_128x2048x2048: +0.01%
addmm_128x768x1000: +0.75%
addmm_12672x3072x768: +1.08%
addmm_512x768x32000: +0.62%
addmm_12608x384x384: +0.00%
addmm_4160x1024x4096: +0.90%
addmm_16x768x2: +0.56%
addmm_12608x3072x768: +0.09%
addmm_64x4096x1000: +2.77%
addmm_256x1024x512: +1.99%
addmm_30x256x256: +1.12%
addmm_100480x128x384: +0.91%
addmm_6400x2048x512: +0.25%
addmm_61568x1024x256: +0.08%
addmm_1x768x768: +0.93%
addmm_12544x384x384: +0.19%
addmm_128x512x1000: +0.77%
addmm_2048x128x128: +1.32%
addmm_128x3072x1000: +0.24%
addmm_7936x512x2048: +0.07%
addmm_8192x512x2048: +0.33%
addmm_64x1024x1000: +1.43%
addmm_128x2304x1000: +0.01%
addmm_32768x256x2: +0.75%
addmm_64x384x1152: +0.79%
addmm_64x640x1000: +0.01%
addmm_100480x128x128: +0.87%
addmm_1152x3072x768: +1.13%
addmm_8192x256x2048: +1.40%
addmm_4096x128x768: +0.01%
addmm_128x2560x1000: +0.01%
addmm_12544x2048x512: +0.43%
addmm_200704x24x96: +0.14%
addmm_8448x512x2048: +0.96%
addmm_50176x256x1024: +0.62%
addmm_4160x4096x1024: +0.22%
addmm_4096x768x768: +0.32%
addmm_220x2048x512: +0.56%
addmm_8x2048x1000: +1.12%
addmm_256x197951x512: +26.99%
addmm_401536x64x192: +0.60%
addmm_2040x2048x512: +0.47%
addmm_512x1024x256: +1.32%
addmm_128x4096x1000: +1.67%
addmm_12672x768x768: +0.34%
addmm_128x368x1000: +0.77%
addmm_96x1280x1000: +0.01%
addmm_12544x512x2048: +0.41%
addmm_6272x320x1280: +0.76%
addmm_12544x3072x768: +0.09%
addmm_64x384x1000: +0.39%
mm improvements when best:
mm_200704x128x512: +1.29%
mm_663552x16x16: +0.80%
mm_4096x768x768: +0.51%
mm_131072x64x31: +0.24%
mm_12544x1152x384: +0.11%
mm_128x2048x2: +0.46%
mm_262144x16x23: +0.62%
mm_50176x576x192: +0.37%
mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%
Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)
Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%
```
## Results
The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```
It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
addmm_128x16384x128: +0.18%
addmm_128x262144x256: +38.24%
addmm_128x200000x512: +14.76%
addmm_256x800000x128: +0.06%
addmm_131072x128x256: +0.27%
addmm_128x256x131072: +0.25%
addmm_2048x200000x64: +12.45%
mm improvements when best:
mm_128x16384x128: +0.18%
mm_128x262144x256: +38.05%
mm_128x200000x512: +9.47%
mm_256x800000x128: +0.99%
mm_512x6400000x256: +3.17%
mm_524288x64x64: +0.29%
mm_2048x200000x64: +11.19%
mm_8192x1000000x256: +34.14%
mm_128x4096x100000: +0.40%
mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%
Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%
```
## Conclusion
Contiguous subgraph decomposition seems worthwhile for `mm` and `addmm`, but not `bmm`, and gives a very large improvement on low-`M`, low-`N`, high-`K` shapes.
Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866
## Test Plan:
New unit tests.
Differential Revision: D80771648
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161241
Approved by: https://github.com/eellison
# why
- another step towards get_mm_configs providing all the kwargs needed to add a choice from a template. This in turn will allow us to send all templates through one single call, and handle modifications
# what
- use the infrastructure for template heuristics to provide extra kwargs that are fixed for a template/op pair, to provide the suffix args and epilogue function/fn for scaled_mm
# testing
```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```
Differential Revision: [D80670914](https://our.internmc.facebook.com/intern/diff/D80670914)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161126
Approved by: https://github.com/jansel
ghstack dependencies: #161123, #161124, #161125
Summary:
An issue I noticed while fixing tests for TMA store: this `triton.language.make_tensor_descriptor` call hardcodes the shape information as the stride, which is not necessarily correct.
In particular, it's legal to have a stride bigger than the shape (e.g. padded to a size). A good example is allocating a tensor whose rows are always padded to a multiple of 16 so that TMA is legal.
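A concrete instance of the shape-vs-stride distinction (plain PyTorch, just to illustrate why reusing the shape as the stride is wrong for padded buffers):
```python
import torch

# Rows allocated with padding to a multiple of 16 so TMA alignment holds,
# while the logical row length is smaller.
storage = torch.empty(128, 160, device="cuda", dtype=torch.float16)
t = storage[:, :150]          # logical shape (128, 150)

print(t.shape)                # torch.Size([128, 150])
print(t.stride())             # (160, 1) -- stride[0] != shape[1]
# A tensor descriptor built from this tensor must carry stride 160, not 150.
```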
This is a redo of https://github.com/pytorch/pytorch/pull/160493 because I accidentally broke it by trying to land internally first instead of merging through GitHub directly.
Test Plan:
Tested with `buck2 run mode/opt-split-dwarf mode/inplace -c fbcode.nvcc_arch=h100 caffe2/test/inductor:max_autotune 2>&1 | tee ~/test_logs.log` and confirmed all max autotune tests passed.
Differential Revision: D80224578
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160614
Approved by: https://github.com/eellison
This is needed for subprocesses that are trying to call back into torch functionality, i.e. anything that's also setting `PYTHONPATH`. If they're part of an application that bundles the Python runtime, then they should use the bundled runtime to keep their view of the world consistent.
There are more `sys.executable` subprocesses in torch/ but it seems like they're fine.
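The pattern in question, roughly (an illustrative sketch of a subprocess that calls back into torch, not the exact call sites changed here):
```python
import os
import subprocess
import sys

# Launch the helper with the same interpreter as the parent so that a bundled
# Python runtime (and the PYTHONPATH view it implies) stays consistent.
env = dict(os.environ, PYTHONPATH=os.pathsep.join(sys.path))
subprocess.run(
    [sys.executable, "-c", "import torch; print(torch.__version__)"],
    env=env,
    check=True,
)
```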
Previous PR at https://github.com/pytorch/pytorch/pull/159382, but it was reverted because it caused macOS jobs on GitHub to time out. Inductor subprocesses were scheduling C++ compilation tasks that failed to find the Python.h header, because they were running in venvs and were now looking for the CPython headers inside the venv, where the headers do not exist. This PR gates the new behavior to internal builds only.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160008
Approved by: https://github.com/aorenste
Summary: When compiling for standalone, make embed_kernel_binary and emit_multi_arch_kernel default to True, and add a default name for model_name_for_generated_files to make the generated cpp project easier to understand. Also improve the weights object-file naming to be more readable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158560
Approved by: https://github.com/yushangdi
**Motivation / Context**: (what I _think_ is happening here)
In "eager"/just-in-time PT2 usage, dynamo/inductor will guard on whether indices fit in int32 or not. So it's generally safe in Inductor code to rely on the example values for symbolic ints in order to determine whether indices fit in int32, because the indices will be guarded on anyway; and if the inputs ever increase to `>int32_max`, dynamo will cause a recompilation.
But with AOTI, those int32 guards aren't respected; so if the example input is `< int32_max` but can be `> int32_max` during future execution, then the future execution might fail / IMA.
**Solution space**
Export allows users to specify which dimensions are dynamic, and to provide **ranges of valid sizes**.
One solution idea is to always respect the upper bound of the dynamic shape range when doing AOTI; if the index's range includes values `>int32_max`, then don't use the hint and assume that this index doesn't fit in int32.
However, the problem with this is that many users may specify dynamism without specifying a range of values - the upper bound of the range will be set to the default of `inf`. Such use cases could potentially experience a perf regression if we implemented the idea above.
To prevent any such regressions, this implementation will rely solely on the specified range only if the upper bound of the range isn't inf. In other words, we'll ignore the hints/example values for AOTI (and rely only on the specified range) only if the upper bound of the range isn't inf - if users explicitly specify a range that extends past int32, we can be fairly sure that they actually do need values `>int32_max`.
If we continue to see correctness issues even with this implementation, we could consider more aggressively relying on the ranges.
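A sketch of how a user would opt in (assuming the current torch.export / AOTI packaging entry points; the exact knobs are not part of the diff shown here): giving the dynamic dimension an explicit finite upper bound above int32_max tells AOTI to plan for 64-bit indexing rather than trusting the example value.
```python
import torch
from torch.export import Dim

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1

x = torch.randn(1024, device="cuda")
# Explicit finite upper bound past int32_max: the example value (1024) is no
# longer trusted and 64-bit indexing is used. With the default unbounded range,
# the hint is still respected (per this PR).
dim0 = Dim("dim0", max=2**40)
ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {0: dim0}})
pkg = torch._inductor.aoti_compile_and_package(ep)
```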
Differential Revision: [D79220301](https://our.internmc.facebook.com/intern/diff/D79220301)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159433
Approved by: https://github.com/jingsh, https://github.com/ColinPeppler