Mutable custom operators get wrapped into an auto_functionalized HOP, so
we need to store the arg_kwarg_vals on the auto_functionalized HOP
itself.
When Inductor does the re-inplacing, it'll use the pattern matcher to
decompose the auto_functionalized HOP back into the original op (and
0+ other view or clone operations). The pattern matcher uses the
arg_kwarg_vals to trace the subgraph to do the decomposition, so it
ultimately sets arg_kwarg_vals on the original op's node correctly.
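As a hedged sketch of the bookkeeping (helper names are hypothetical; the real logic lives in Inductor's internals), the idea is to stash the eager-mode arg/kwarg values on the HOP node's meta so the re-inplacing pass can retrace the decomposition later:
```python
import torch.fx

# Hypothetical helpers, not Inductor's actual API: stash the eager-mode
# arg/kwarg values on the auto_functionalized node's meta so the pattern
# matcher can retrace the decomposition with accurate eager metadata.
def stash_arg_kwarg_vals(node: torch.fx.Node, args, kwargs) -> None:
    node.meta["arg_kwarg_vals"] = (args, kwargs)

def get_arg_kwarg_vals(node: torch.fx.Node):
    return node.meta.get("arg_kwarg_vals")
```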
Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148091
Approved by: https://github.com/eellison
ghstack dependencies: #148046, #148063
Instead of always propagating arg_kwarg_vals in _COPY_META_FIELDS, we
special-case the pattern matcher to propagate arg_kwarg_vals when
it sees triton_kernel_wrapper_functional.
The strategy is:
1) Trace out the replacement graph with arg_kwarg_vals (which have accurate eager-mode metadata).
2) Trace out the replacement graph with vals (which have the accurate Inductor metadata).
3) Propagate the arg_kwarg_vals from the first graph to the second.
4) Use the second graph as the replacement graph.
We chose this strategy because we want to extend it to handle
auto_functionalized later up in the stack.
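A minimal sketch of step 3, assuming both traces yield graphs whose nodes correspond one-to-one (the helper name is hypothetical):
```python
import torch.fx

# Walk both traced graphs in lockstep and copy arg_kwarg_vals from the
# eager-metadata trace onto the Inductor-metadata trace, which is the one
# that gets used as the replacement graph.
def propagate_arg_kwarg_vals(
    eager_graph: torch.fx.Graph, inductor_graph: torch.fx.Graph
) -> None:
    for eager_node, inductor_node in zip(eager_graph.nodes, inductor_graph.nodes):
        if "arg_kwarg_vals" in eager_node.meta:
            inductor_node.meta["arg_kwarg_vals"] = eager_node.meta["arg_kwarg_vals"]
```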
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148046
Approved by: https://github.com/eellison
Found an issue while running `python torchgen/fuse/gen_patterns.py`.
Exact error:
```shell
Traceback (most recent call last):
File "/Users/mayankmishra/Desktop/non-IBM/pytorch/torchgen/fuse/gen_patterns.py", line 19, in <module>
joint_graph.lazy_init()
File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py", line 2096, in lazy_init
result = fn()
File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 53, in lazy_init
_pad_mm_init()
File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/fx_passes/pad_mm.py", line 905, in _pad_mm_init
gen_register_replacement(
File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py", line 1584, in gen_register_replacement
pat = _serialize_pattern(
File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py", line 1539, in _serialize_pattern
file_template = get_file_template()
File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py", line 1513, in get_file_template
if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/abc.py", line 123, in __subclasscheck__
return _abc_subclasscheck(cls, subclass)
TypeError: issubclass() arg 1 must be a class
```
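For reference, `issubclass` raises exactly this `TypeError` whenever its first argument is not an actual class; a standalone illustration (not taken from the PR):
```python
import typing

# issubclass() requires a real class as its first argument; objects such
# as typing aliases are not classes and trigger the same TypeError.
try:
    issubclass(typing.List[int], object)
except TypeError as e:
    print(e)  # issubclass() arg 1 must be a class
```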
This PR fixes this issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147723
Approved by: https://github.com/aorenste
Co-authored-by: Aaron Orenstein <aorenste@meta.com>
…t of the pattern module
For example, any pattern module containing the following generated pattern fails to import because the name `getitem` is undefined:
```python
native_dropout_default = CallFunction(aten.native_dropout.default, div_Tensor_1, KeywordArg('dropout_p'), True, _users=2)
getitem = CallFunction(getitem, native_dropout_default, 0)
```
This fix resolves the error.
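For illustration, a minimal preamble of the kind the generated module needs, assuming the serialized `getitem` stands for `operator.getitem` (a sketch, not the PR's exact change):
```python
import operator

# The generated pattern references a bare `getitem`; without a binding
# like this, `CallFunction(getitem, ...)` raises NameError at import time.
getitem = operator.getitem
```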
Fixes #144674
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144980
Approved by: https://github.com/eellison
Summary:
- use GraphTransformObserver + replace_node hooks to track node sources when they are replaced
- add pre_grad_graph tracking to tlparse
- add the node provenance information to post_grad_graph tlparse. This is for the frontend to create a mapping between pre_grad and post_grad graph. See an example frontend (this is just a prototype) here: https://drive.google.com/file/d/1cMHH_0y4FJUSS9tATwGQvA72O0Lth8eh/view?usp=sharing
- change "action" of NodeSource from a single action to a list of actions.
- It's BC-breaking because we removed `GraphTransformObserver`'s class methods `on_node_erase` and `on_node_replace`.
https://docs.google.com/document/d/1dGh9myqNhywmbfP0Quzx_f04bghDFlj8cawj8MopiO8/edit?tab=t.0
The front-end code that takes in the tlparse result is in https://github.com/yushangdi/compiler_explorer.
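A hedged usage sketch of the observer-based tracking (the pass body is a placeholder; the exact hook wiring is internal and may differ across versions):
```python
import torch
from torch.fx.passes.graph_transform_observer import GraphTransformObserver

def my_pass(gm: torch.fx.GraphModule) -> None:
    # A real pass would create/replace/erase nodes here; the observer's
    # hooks record those events so node provenance can be reconstructed.
    for node in gm.graph.nodes:
        pass

gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
with GraphTransformObserver(gm, "my_pass"):
    my_pass(gm)
```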
ghstack-source-id: 260390519
Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer
buck run mode/dev-nosan fbcode//caffe2/test:fx -- -r node_source
buck run mode/dev-nosan fbcode//caffe2/test:fx -- -r graph_provenance
```
Front-end example screenshots on a real model, 93% coverage rate between pre_grad_graph and post_grad_graph
```
buck2 build --show-output mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true -c fbcode.nvcc_arch=a100,h100 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark
MODEL_ENTITY_ID=644688112
SNAPSHOT_ID=32
MODULE=merge
TORCH_COMPILE_DEBUG=1 CUDA_VISIBLE_DEVICES=7 TORCH_LOGS="+inductor,+schedule,output_code,graph_code" TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 ../buck-out/v2/gen/fbcode/ec86b05dd59e84db/caffe2/torch/fb/model_transform/experimental/benchmark/__mts_gpu_benchmark__/mts_gpu_benchmark.par --local-model /home/bahuang/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend AOT_INDUCTOR_EP --gpu-trace --aot-inductor-config="{'max_autotune': True}"
buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:auto_functionalize
```
Differential Revision: D65006709
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144277
Approved by: https://github.com/desertfire
vllm hit an error because we were incorrectly stating that two patterns are duplicates. See the comment inline.
For a particular generated pattern repr, store all the equivalent graphs that were used to generate it. Because we ignore certain patterns in searching, but not in matching, use the graph to distinguish whether two equivalent searches are actually different.
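A minimal sketch of the bookkeeping (names here are hypothetical, not the PR's actual API):
```python
from collections import defaultdict

# Map each generated pattern repr to every graph string that produced it;
# a candidate is only a true duplicate if its graph was already recorded.
_equivalent_graphs: defaultdict[str, list[str]] = defaultdict(list)

def is_duplicate(pattern_repr: str, graph_str: str) -> bool:
    graphs = _equivalent_graphs[pattern_repr]
    if graph_str in graphs:
        return True
    graphs.append(graph_str)
    return False
```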
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139321
Approved by: https://github.com/shunting314
Fixes #137280
When we have multiple indexings of the same array as returned items in a pattern replacement, we shouldn't ignore the indexing numbers. Otherwise, we may create a wrong pattern-to-node mapping.
A unit test is added in this PR. In it, the function `rms_pattern_static` is replaced with `rms_replacement_static` when called. `rms_pattern_static` calls two functionalized custom operators, `torch.ops.vllm.rms_norm.default` and `torch.ops.vllm.static_scaled_int8_quant.default`, and returns `at2[1]` and `at2[2]` as outputs. `rms_replacement_static` calls one functionalized custom operator, `torch.ops.vllm.fused_rms_norm_quant_static.default`, which returns the two corresponding items.
Run `python test/inductor/test_pattern_matcher.py -k test_multioutput_register_replacement` to test. After setting `TORCH_COMPILE_DEBUG` to 1, the final part of `fx_graph_readable.py` looks like the following.
```python
# File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1673 in rms_pattern_static, code: at1 = auto_functionalized(
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.rms_norm.default, result = permute_1, input = convert_element_type, weight = convert_element_type_1, epsilon = 1e-06); permute_1 = convert_element_type = convert_element_type_1 = None
getitem_1: "bf16[5, 4]" = auto_functionalized[1]; auto_functionalized = None
# File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1680 in rms_pattern_static, code: at2 = auto_functionalized(
auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.static_scaled_int8_quant.default, result = permute, input = getitem_1, scale = full_default, azp = None); permute = getitem_1 = full_default = None
getitem_3: "i8[5, 4]" = auto_functionalized_1[1]
getitem_4: "f32[1, 1]" = auto_functionalized_1[2]; auto_functionalized_1 = None
return (getitem_3, getitem_4)
```
This happens before pattern matching, so it is expected to call `static_scaled_int8_quant` and `rms_norm` and return `auto_functionalized_1`'s items as outputs.
However, with PyTorch before this PR, `fx_graph_transformed.py` (produced after pattern matching) contains the following code.
```python
# File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1748 in my_func_static, code: scale = torch.ones((1, 1))
full_default: "f32[1, 1]" = torch.ops.aten.full.default([1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
# No stacktrace found for following nodes
as_strided_default: "i8[20]" = torch.ops.aten.as_strided.default(permute, [20], [1], 0)
clone_default: "i8[20]" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None
as_strided_default_1: "i8[5, 4]" = torch.ops.aten.as_strided.default(clone_default, [5, 4], [4, 1], 0); clone_default = None
as_strided_default_2: "f32[1]" = torch.ops.aten.as_strided.default(full_default, [1], [1], 0)
clone_default_1: "f32[1]" = torch.ops.aten.clone.default(as_strided_default_2); as_strided_default_2 = None
as_strided_default_3: "f32[1, 1]" = torch.ops.aten.as_strided.default(clone_default_1, [1, 1], [1, 1], 0); clone_default_1 = None
static_scaled_int8_quant_default = torch.ops.vllm.static_scaled_int8_quant.default(as_strided_default_1, permute_1, as_strided_default_3); as_strided_default_1 = permute_1 = static_scaled_int8_quant_default = None
fused_rms_norm_quant_static_default = torch.ops.vllm.fused_rms_norm_quant_static.default(permute, convert_element_type, convert_element_type_1, full_default, None, 1e-06); convert_element_type = convert_element_type_1 = full_default = fused_rms_norm_quant_static_default = None
return (permute, as_strided_default_3)
```
Here, it returns `(permute, as_strided_default_3)`, where `permute` is written by `fused_rms_norm_quant_static` and `as_strided_default_3` is written by `static_scaled_int8_quant`. This is wrong: we expect `static_scaled_int8_quant` to be removed entirely, since it is replaced with `fused_rms_norm_quant_static`, and the graph is supposed to return `(permute, full_default)`.
The root cause is the following. When we generate patterns from the traced fx graph (5f4a21dc58/torch/_inductor/pattern_matcher.py (L1580)) and call the function below, the integer indexing numbers in the traced graph are dropped because `int` is in `ignore_types`. So the final arguments of the patterns for those two output items look like `(CallFunction(auto_functionalized, XXX), *)`.
5f4a21dc58/torch/_inductor/pattern_matcher.py (L1839-L1847)
When we do pattern matching after generating the patterns, `sorted(itertools.chain.from_iterable(nodes), reverse=True)` is `[getitem_4, getitem_3, getitem_1]`. The iteration for `getitem_4` is always a FailedMatch, because we always use the first element to do the pattern match (it fails in different match functions before and after this PR, but the reason is always the indexing-numbers issue):
d4cdc09881/torch/_inductor/pattern_matcher.py (L848)
However, when we do pattern matching for `getitem_3`, the `child_match` returns a match for `getitem_3` again, because the `*` pattern can match anything. The pattern matching for `getitem_3` then returns `[getitem_3, getitem_3]` as outputs, which is wrong.
d4cdc09881/torch/_inductor/pattern_matcher.py (L856)
d4cdc09881/torch/_inductor/pattern_matcher.py (L1750-L1774)
This PR stops ignoring the `int` type when generating patterns for getitem functions, because integer indexing numbers are essential to them. The indexing information is thus kept in the patterns, ensuring correct matches. With this PR, the `child_match` above returns a match for `getitem_4`, and the final pattern matching for `getitem_3` returns the correct `[getitem_3, getitem_4]`.
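A simplified sketch of the idea (the helper below is hypothetical, not Inductor's actual code): when converting a traced call's argument into a pattern, keep integer indices for `operator.getitem` instead of wildcarding them:
```python
import operator

def arg_to_pattern(target, arg, ignore_types=(int, float)):
    # Integer indices are significant for operator.getitem: they select
    # which output of a multi-output op the pattern must match.
    if target is operator.getitem and isinstance(arg, int):
        return arg
    if isinstance(arg, ignore_types):
        return "*"  # wildcard: matches anything
    return arg
```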
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140193
Approved by: https://github.com/eellison
* Automatically applies ruff rule PERF401, turning loops into equivalent list comprehensions, which are faster and do not leak the loop variables into the enclosing scope (see the example after this list).
* List comprehensions not only often have better typing, but also carry 50+% less interpreter overhead than for loops. They also preserve length information and are easier for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
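For example, the kind of rewrite applied:
```python
# Before: a loop that conditionally appends into a list.
squares = []
for x in range(10):
    if x % 2 == 0:
        squares.append(x * x)

# After: the equivalent (and faster) list comprehension; `x` also no
# longer leaks into the enclosing scope.
squares = [x * x for x in range(10) if x % 2 == 0]
```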
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
replace_by_example is used to implement some pattern-matching passes in Inductor. Previously, replace_by_example would generate nodes with very little metadata. In particular, `meta["original_aten"]` would be lost; that meant that when generating triton kernel names, you could get generic names like `triton_tem_fused_0` if the input nodes to the fused kernel were the result of a pattern-matching pass that used replace_by_example.
This also adds metadata to register_replacement patterns, including pad_mm.
This fixes the issue by copying metadata from the original node to the replacement nodes. If there are multiple original nodes, we skip the metadata transfer; so if you have an `add(z, mm(x, y))`, the metadata won't be transferred right now.
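A simplified sketch of that rule (hypothetical helper, not the actual Inductor code):
```python
def transfer_original_aten(original_nodes, replacement_nodes) -> None:
    # With multiple originals (e.g. add(z, mm(x, y))) the mapping is
    # ambiguous, so the transfer is skipped entirely.
    if len(original_nodes) != 1:
        return
    (orig,) = original_nodes
    for node in replacement_nodes:
        if "original_aten" in orig.meta:
            node.meta.setdefault("original_aten", orig.meta["original_aten"])
```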
Differential Revision: [D64480755](https://our.internmc.facebook.com/intern/diff/D64480755)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138089
Approved by: https://github.com/aakhundov
The pattern matcher runs DCE and remove_noop_ops on the replacement
graph by default. Previously we had a switch for the DCE. This PR
changes that switch to also control whether we run remove_noop_ops.
The context was that there is silent incorrectness with
auto_functionalized. We use the pattern matcher to decompose
auto_functionalized into a mutable op + clones; remove_noop_ops was
deleting the clones.
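An illustrative sketch of the new behavior (names are hypothetical):
```python
import torch.fx

def cleanup_replacement_graph(graph: torch.fx.Graph, run_passes: bool) -> None:
    # One switch now gates both cleanup passes. Callers decomposing
    # auto_functionalized pass run_passes=False so the inserted clones
    # survive into the post-grad graph.
    if not run_passes:
        return
    graph.eliminate_dead_code()  # DCE
    remove_noop_ops(graph)       # stand-in for Inductor's noop-removal pass

def remove_noop_ops(graph: torch.fx.Graph) -> None:
    pass  # placeholder: the real pass removes no-op view/clone chains
```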
Future: can try #134363
Test Plan:
- new test. I wasn't able to produce a silently incorrect example, so I settled for asserting that clones still exist in the post-grad graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134364
Approved by: https://github.com/eellison
ghstack dependencies: #133639
Part of #134054.
This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So we land these 'type: ignore's for pytorch in advance of them actually being needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007