Commit Graph

178 Commits

Author SHA1 Message Date
e95e8eed0a mypy 1.16.0 (#155821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-06-14 18:18:43 +00:00
d1947a8707 Migrate from lru_cache to cache (#155613)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155613
Approved by: https://github.com/ezyang
ghstack dependencies: #155612
2025-06-11 19:44:18 +00:00
0827464002 Replace runtime type parameterization (#155221)
See:

```
>>> import timeit; print(f"OrderedSet[str](): {timeit.timeit('OrderedSet[str]()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s, OrderedSet(): {timeit.timeit('OrderedSet()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s")
```
> `OrderedSet[str]()`: 0.354622s, OrderedSet(): 0.095376s

Type parameterization should be on type hint, not in runtime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155221
Approved by: https://github.com/Skylion007, https://github.com/jansel
2025-06-05 21:43:54 +00:00
aae36929ed Rename node.meta["arg_kwarg_vals"] to node.meta["eager_input_vals"] (#148092)
And added a comment about it. Otherwise it might be confusing

Test Plan:
- wait for CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148092
Approved by: https://github.com/eellison
ghstack dependencies: #148046, #148063, #148091
2025-04-02 13:18:04 +00:00
4d121d2b02 Implement needs_exact_strides for mutable custom operators (#148091)
Mutable custom operators get wrapped into an auto_functionalized HOP, so
we need to store the arg_kwarg_vals on the auto_functionalized HOP
itself.

When Inductor does the re-inplacing, it'll use the pattern matcher to
decompose the auto_functionalized HOP back into the original op (and
0+ other view or clone operations). The pattern matcher uses the
arg_kwarg_vals to trace the subgraph to do the decomposition, so it
ultimately sets arg_kwarg_vals on the original op's node correctly.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148091
Approved by: https://github.com/eellison
ghstack dependencies: #148046, #148063
2025-04-02 13:18:04 +00:00
c41fbb4f78 Change arg_kwarg_vals propagation strategy (#148046)
Instead of always propagating arg_kwarg_vals in _COPY_META_FIELDS, we
special-case the pattern matcher to propagate arg_kwarg_vals when
it sees triton_kernel_wrapper_functional.

The strategy is:
1) trace out the replacement graph with arg_kwarg_vals (which have accurate eager-mode metadata)
2) trace out the replacement graph with vals (which have the accurate Inductor metadata)
3) Propagate the arg_kwarg_vals from the first graph to the second.
4) Use the second graph as the replacement graph.

The strategy is this because we want to extend this to handle
auto_functionalized later up in the stack.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148046
Approved by: https://github.com/eellison
2025-04-02 13:17:52 +00:00
4a4a71a73c [inductor]lowering scan to while_loop (#148580)
This PR add a pass in post_grad that lowers scan to while_loop. See the comment before the pass for how this is implemented.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148580
Approved by: https://github.com/jansel, https://github.com/eellison
2025-03-20 20:21:02 +00:00
37c914ca0c fix simple-spec crash (#147723)
found an issue while running `python torchgen/fuse/gen_patterns.py`

exact error:
```shell
Traceback (most recent call last):
  File "/Users/mayankmishra/Desktop/non-IBM/pytorch/torchgen/fuse/gen_patterns.py", line 19, in <module>
    joint_graph.lazy_init()
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py", line 2096, in lazy_init
    result = fn()
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 53, in lazy_init
    _pad_mm_init()
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/fx_passes/pad_mm.py", line 905, in _pad_mm_init
    gen_register_replacement(
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py", line 1584, in gen_register_replacement
    pat = _serialize_pattern(
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py", line 1539, in _serialize_pattern
    file_template = get_file_template()
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py", line 1513, in get_file_template
    if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.10/abc.py", line 123, in __subclasscheck__
    return _abc_subclasscheck(cls, subclass)
TypeError: issubclass() arg 1 must be a class
```

This PR fixes this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147723
Approved by: https://github.com/aorenste

Co-authored-by: Aaron Orenstein <aorenste@meta.com>
2025-03-17 23:25:48 +00:00
6b44a91a62 use statically_known_true instead of guard_size_oblivious in pattern matcher (#147557)
We shouldn't add guards here. Use statically_known_true instead. Internal xref: https://fb.workplace.com/groups/1075192433118967/?multi_permalinks=1609560723015466&comment_id=1610040026300869&notif_id=1740082892544333&notif_t=work_feedback_reaction_generic&ref=notif

Differential Revision: [D69950122](https://our.internmc.facebook.com/intern/diff/D69950122/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147557
Approved by: https://github.com/eellison
2025-03-07 19:17:25 +00:00
1cb4e2df65 [BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel
2025-02-28 13:33:19 +00:00
80d3afc698 [inductor] Improve type annotations in _inductor/pattern_matcher.py (#146626)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146626
Approved by: https://github.com/Skylion007
2025-02-24 14:30:35 +00:00
a50af71fb6 [FX] Refactor immutable collections implementation (#144640)
Get rid of dynamic class creation via `type(name, bases, ...)`. Convert it to classic static class definition for better readability and static analysis support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144640
Approved by: https://github.com/jansel
ghstack dependencies: #147699
2025-02-24 09:14:08 +00:00
880e176544 [inductor] Fix for pattern file contains 'getitem' fails during impor… (#144980)
…t of the pattern module

  For example any pattern module that has the following pattern generated, fails to import because
  the name getitem undefined.

  native_dropout_default = CallFunction(aten.native_dropout.default, div_Tensor_1, KeywordArg('dropout_p'), True, _users=2)
  getitem = CallFunction(getitem, native_dropout_default, 0)

  this fix will resolve the error.

Fixes #144674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144980
Approved by: https://github.com/eellison
2025-02-14 02:30:24 +00:00
bac62341eb PEP585 update - torch/_inductor (#145198)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145198
Approved by: https://github.com/bobrenjc93
2025-01-21 21:04:33 +00:00
379b54603a [Inductor] [bc-breaking] Node Level provenance tracking (#144277)
Summary:

- use GraphTransformObserver + replace_node hooks to track node sources when they are replaced
- add pre_grad_graph tracking to tlparse
- add the node provenance information to post_grad_graph tlparse. This is for the frontend to create a mapping between pre_grad and post_grad graph. See an example frontend (this is just a prototype) here:  https://drive.google.com/file/d/1cMHH_0y4FJUSS9tATwGQvA72O0Lth8eh/view?usp=sharing
- change "action" of NodeSource from a single action to a list of actions.

- It's BC-Breaking because we removed `GraphTransformObserver`'s class methods `on_node_erase` and `on_node_erase` .

https://docs.google.com/document/d/1dGh9myqNhywmbfP0Quzx_f04bghDFlj8cawj8MopiO8/edit?tab=t.0

The front-end code that takes in the tlparse result is in https://github.com/yushangdi/compiler_explorer.
ghstack-source-id: 260390519

Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r node_source
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r graph_provenance
```

Front-end example screenshots on a real model, 93% coverage rate between pre_grad_graph and post_grad_graph

 {F1973584210}{F1973584209}

```
buck2 build --show-output mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true -c fbcode.nvcc_arch=a100,h100 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark

MODEL_ENTITY_ID=644688112
SNAPSHOT_ID=32
MODULE=merge

TORCH_COMPILE_DEBUG=1 CUDA_VISIBLE_DEVICES=7 TORCH_LOGS="+inductor,+schedule,output_code,graph_code" TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 ../buck-out/v2/gen/fbcode/ec86b05dd59e84db/caffe2/torch/fb/model_transform/experimental/benchmark/__mts_gpu_benchmark__/mts_gpu_benchmark.par --local-model /home/bahuang/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend AOT_INDUCTOR_EP --gpu-trace --aot-inductor-config="{'max_autotune':
True}"

buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:auto_functionalize
```

Differential Revision: D65006709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144277
Approved by: https://github.com/desertfire
2025-01-09 22:06:51 +00:00
a3ab27b8e0 Migrate from Tuple -> tuple in torch/_inductor (#144264)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144264
Approved by: https://github.com/eellison
2025-01-07 03:27:27 +00:00
e1622dca7a Fix duplicate pattern error (#139321)
vllm had an error when we were incorrectly stating two patterns are duplicates. See, comment inline:

For a particular generated pattern repr, store all the equivalent graphs that used to generate them. Because we ignore certain patterns in searching, but not in matching, use the graph to distinguish if two equivalent searches are actually different.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139321
Approved by: https://github.com/shunting314
2025-01-06 05:04:59 +00:00
99f2491af9 Revert "Use absolute path path.resolve() -> path.absolute() (#129409)"
This reverts commit 45411d1fc9a2b6d2f891b6ab0ae16409719e09fc.

Reverted https://github.com/pytorch/pytorch/pull/129409 on behalf of https://github.com/jeanschmidt due to Breaking internal CI, @albanD please help get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/129409#issuecomment-2571316444))
2025-01-04 14:17:20 +00:00
45411d1fc9 Use absolute path path.resolve() -> path.absolute() (#129409)
Changes:

1. Always explicit `.absolute()`: `Path(__file__)` -> `Path(__file__).absolute()`
2. Replace `path.resolve()` with `path.absolute()` if the code is resolving the PyTorch repo root directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129409
Approved by: https://github.com/albanD
2025-01-03 20:03:40 +00:00
a174ee2255 Revert "Fix duplicate pattern error (#139321)"
This reverts commit 9e8d84f8631317ce61de4f0f9731fc1b1ccc3d2b.

Reverted https://github.com/pytorch/pytorch/pull/139321 on behalf of https://github.com/jeanschmidt due to breaking internal signals ([comment](https://github.com/pytorch/pytorch/pull/139321#issuecomment-2566620095))
2024-12-31 17:44:02 +00:00
9e8d84f863 Fix duplicate pattern error (#139321)
vllm had an error when we were incorrectly stating two patterns are duplicates. See, comment inline:

For a particular generated pattern repr, store all the equivalent graphs that used to generate them. Because we ignore certain patterns in searching, but not in matching, use the graph to distinguish if two equivalent searches are actually different.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139321
Approved by: https://github.com/shunting314
2024-12-27 11:10:46 +00:00
cc4e70b7c3 Revert "Use absolute path path.resolve() -> path.absolute() (#129409)"
This reverts commit 135c7db99d646b8bd9603bf969d47d3dec5987b1.

Reverted https://github.com/pytorch/pytorch/pull/129409 on behalf of https://github.com/malfet due to need to revert to as dependency of https://github.com/pytorch/pytorch/pull/129374 ([comment](https://github.com/pytorch/pytorch/pull/129409#issuecomment-2562969825))
2024-12-26 17:26:06 +00:00
135c7db99d Use absolute path path.resolve() -> path.absolute() (#129409)
Changes:

1. Always explicit `.absolute()`: `Path(__file__)` -> `Path(__file__).absolute()`
2. Replace `path.resolve()` with `path.absolute()` if the code is resolving the PyTorch repo root directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129409
Approved by: https://github.com/albanD
2024-12-24 08:33:08 +00:00
dec4286b2d [inductor] Fix for extract_target with dots (#143766)
Fixes #143650

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143766
Approved by: https://github.com/yanboliang
2024-12-24 03:42:15 +00:00
da67a6a7bb [inductor] Replace set by OrderedSet (#138466)
Uses the set_linter from https://github.com/pytorch/pytorch/pull/138454
and considerable manual editing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138466
Approved by: https://github.com/eellison
2024-12-13 16:08:45 +00:00
ab63b679e9 Save indexing for getitem nodes when do custom replacements (#140193)
Fixes #137280

When we have multiple indexings for the same array as returned items in pattern replacement, we shouldn't ignore its indexing numbers. otherwise, we may create a wrong pattern_to_node mapping.

A unit test is added in this PR. In this unit test, the function `rms_pattern_static` is replaced with `rms_replacement_static` when called. The function `rms_pattern_static` calls two functionalized custom operators, `torch.ops.vllm.rms_norm.default` and `torch.ops.vllm.static_scaled_int8_quant.default`, and it returns at2[1] and at2[2] as outputs. The function `rms_replacement_static` calls one functionalized custom operator `torch.ops.vllm.fused_rms_norm_quant_static.default`, which returns two corresponding items.

Run `python test/inductor/test_pattern_matcher.py -k test_multioutput_register_replacement` to test. After set `TORCH_COMPILE_DEBUG` to 1, the final part of the `fx_graph_readable.py` is like the following.
```python
# File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1673 in rms_pattern_static, code: at1 = auto_functionalized(
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.rms_norm.default, result = permute_1, input = convert_element_type, weight = convert_element_type_1, epsilon = 1e-06);  permute_1 = convert_element_type = convert_element_type_1 = None
getitem_1: "bf16[5, 4]" = auto_functionalized[1];  auto_functionalized = None

# File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1680 in rms_pattern_static, code: at2 = auto_functionalized(
auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.static_scaled_int8_quant.default, result = permute, input = getitem_1, scale = full_default, azp = None);  permute = getitem_1 = full_default = None
getitem_3: "i8[5, 4]" = auto_functionalized_1[1]
getitem_4: "f32[1, 1]" = auto_functionalized_1[2];  auto_functionalized_1 = None
return (getitem_3, getitem_4)
```
This happens before pattern matching, so is it expected to call `static_scaled_int8_quant` and `rms_norm` and return auto_functionalized_1 as outputs.

However, for pytorch before this PR, the `fx_graph_transformed.py`, which is after pattern matching, has the following code.
```python
 # File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1748 in my_func_static, code: scale = torch.ones((1, 1))
full_default: "f32[1, 1]" = torch.ops.aten.full.default([1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)

# No stacktrace found for following nodes
as_strided_default: "i8[20]" = torch.ops.aten.as_strided.default(permute, [20], [1], 0)
clone_default: "i8[20]" = torch.ops.aten.clone.default(as_strided_default);  as_strided_default = None
as_strided_default_1: "i8[5, 4]" = torch.ops.aten.as_strided.default(clone_default, [5, 4], [4, 1], 0);  clone_default = None
as_strided_default_2: "f32[1]" = torch.ops.aten.as_strided.default(full_default, [1], [1], 0)
clone_default_1: "f32[1]" = torch.ops.aten.clone.default(as_strided_default_2);  as_strided_default_2 = None
as_strided_default_3: "f32[1, 1]" = torch.ops.aten.as_strided.default(clone_default_1, [1, 1], [1, 1], 0);  clone_default_1 = None
static_scaled_int8_quant_default = torch.ops.vllm.static_scaled_int8_quant.default(as_strided_default_1, permute_1, as_strided_default_3);  as_strided_default_1 = permute_1 = static_scaled_int8_quant_default = None
fused_rms_norm_quant_static_default = torch.ops.vllm.fused_rms_norm_quant_static.default(permute, convert_element_type, convert_element_type_1, full_default, None, 1e-06);  convert_element_type = convert_element_type_1 = full_default = fused_rms_norm_quant_static_default = None
return (permute, as_strided_default_3)
```
Here, it returns `(permute, as_strided_default_3)` while `permute` is written by fused_rms_norm_quant_static and `as_strided_default_3` is written by `static_scaled_int8_quant`. This is wrong because in our expectation, the `static_scaled_int8_quant` should be removed since it is replaced with `fused_rms_norm_quant_static`. It is supposed to return `(permute, full_default)`.

The root cause is the following part. When we [generate patterns](5f4a21dc58/torch/_inductor/pattern_matcher.py (L1580)) with traced fx graph and call the following function, the indexing numbers' type int in traced graph are ignored in `ignore_types`. So, the final arguments of patterns for those two output items are like `(CallFunction(auto_functionalized,XXX)), *)`.

5f4a21dc58/torch/_inductor/pattern_matcher.py (L1839-L1847)

When we do pattern matching after we generated patterns in the following part, the `sorted(itertools.chain.from_iterable(nodes), reverse=True)` is `[getitem_4, getitem_3, getitem_1]`. The getitem_4's iteration is always FailedMatch because we always use the first element to do the pattern match here (it fails on different match functions before and after this PR, but the reason is always the indexing numbers issue)d4cdc09881/torch/_inductor/pattern_matcher.py (L848). However, when we do pattern matching for getitem_3, the child_match returns a match for getitem_3 again which is because the `*` pattern can match anything. Then the getitem_3's pattern matching returns a `[getitem_3, getitem_3]` as outputs which are wrong.
d4cdc09881/torch/_inductor/pattern_matcher.py (L856)

d4cdc09881/torch/_inductor/pattern_matcher.py (L1750-L1774)

This PR doesn't ignore `int` type when we generate patterns for getitem functions because integer indexing numbers are important to them. Thus, the indexing information is kept in patterns, ensuring correct matchings. With this PR, the above `child_match` returns a match for getitem_4, and the final getitem_3's pattern matching returns the correct `[getitem_3, getitem_4]`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140193
Approved by: https://github.com/eellison
2024-11-27 20:19:13 +00:00
612122af8f Fix type-safety of torch.nn.Module instances (#141240)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141240
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-11-22 00:05:05 +00:00
12e95aa4ee [BE]: Apply PERF401 autofixes from ruff (#140980)
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
2024-11-20 17:52:07 +00:00
f93ebb2cf4 [Easy] Refactor post grad application of passes (#139293)
Refactors GraphTransformObserver to hook into the bisect manager pass application. And reworks post grad passes to use it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139293
Approved by: https://github.com/exclamaforte
ghstack dependencies: #139292
2024-10-31 17:05:27 +00:00
4db6b740bc [Easy] GraphTransformObserver Refactoring (#139292)
Uses `torch._inductor.config.trace.log_url_for_graph_xform` by default as the log url. It was only ever instantiated with this as the log_url argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139292
Approved by: https://github.com/shengfukevin, https://github.com/shunting314
2024-10-31 00:33:28 +00:00
49ed365b22 [BE]: Update Typeguard to TypeIs for better type inference (#133814)
Uses TypeIs instead of TypeGuard for better inference. See https://peps.python.org/pep-0742/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133814
Approved by: https://github.com/ezyang
2024-10-26 15:07:13 +00:00
86d4b7d60b [FX][export][dynamo] use tuple instead of list in normalized args_spec (#138212)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138212
Approved by: https://github.com/jansel
2024-10-25 06:43:55 +00:00
32d4582e02 Revert "[BE]: Update Typeguard to TypeIs for better type inference (#133814)"
This reverts commit 16caa8c1b3a02e47b5f52d3c2d40d7931cc427dc.

Reverted https://github.com/pytorch/pytorch/pull/133814 on behalf of https://github.com/jeanschmidt due to checking if this will solve inductor errors ([comment](https://github.com/pytorch/pytorch/pull/133814#issuecomment-2427565425))
2024-10-21 19:40:58 +00:00
16caa8c1b3 [BE]: Update Typeguard to TypeIs for better type inference (#133814)
Uses TypeIs instead of TypeGuard for better inference. See https://peps.python.org/pep-0742/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133814
Approved by: https://github.com/ezyang
2024-10-21 17:20:06 +00:00
701ddf962a [inductor] Preserve metadata across replace_by_example and register_replacement patterns (#138089)
replace_by_example is used to implement some pattern-matching passes in inductor. Previously, replace_by_example would generate nodes with very little metadata. In particular, `meta["original_aten"]` would be lost; that meant that when generating triton kernel names, you could get empty names like `triton_tem_fused_0` if the input nodes to the fused kernel were the result of a pattern-matching pass that used replace_by_example.

This also adds metadata for to register_replacement patterns, including pad_mm.

This fixes the issue by copying metadata from the original node to the replacement nodes. If there are multiple original nodes we skip the metadata transfer; so if you have a `add(z, mm(x, y))`, then the metadata won't be transferred right now.

Differential Revision: [D64480755](https://our.internmc.facebook.com/intern/diff/D64480755)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138089
Approved by: https://github.com/aakhundov
2024-10-21 16:33:12 +00:00
47e80abc7a Revert "[inductor] Preserve metadata across replace_by_example and register_replacement patterns (#138089)"
This reverts commit fb44658415e50b5be6a187ff3f14243c0fdf3daf.

Reverted https://github.com/pytorch/pytorch/pull/138089 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but the new test_original_aten_preserved_pad_mm test runs OOM in trunk fb44658415 ([comment](https://github.com/pytorch/pytorch/pull/138089#issuecomment-2424297269))
2024-10-19 23:55:01 +00:00
fb44658415 [inductor] Preserve metadata across replace_by_example and register_replacement patterns (#138089)
replace_by_example is used to implement some pattern-matching passes in inductor. Previously, replace_by_example would generate nodes with very little metadata. In particular, `meta["original_aten"]` would be lost; that meant that when generating triton kernel names, you could get empty names like `triton_tem_fused_0` if the input nodes to the fused kernel were the result of a pattern-matching pass that used replace_by_example.

This also adds metadata for to register_replacement patterns, including pad_mm.

This fixes the issue by copying metadata from the original node to the replacement nodes. If there are multiple original nodes we skip the metadata transfer; so if you have a `add(z, mm(x, y))`, then the metadata won't be transferred right now.

Differential Revision: [D64480755](https://our.internmc.facebook.com/intern/diff/D64480755)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138089
Approved by: https://github.com/aakhundov
2024-10-19 16:37:08 +00:00
8184e202d8 Update mutation checking in pattern matcher (#137448)
Fix for https://github.com/pytorch/pytorch/issues/137229

The current mutation checking is complicated because it works for pre grad IR. When pre grad ir has been traced to OpOverloads checking is much easier. I am also special casing auto functional hop although I discussed with @zou3519 it would be nice to have a way of querying HOPs that mimic schemas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137448
Approved by: https://github.com/zou3519
2024-10-08 16:56:40 +00:00
193c547461 [inductor] Refactor simplify erase_nodes() (#134822)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134822
Approved by: https://github.com/shunting314
ghstack dependencies: #134748, #134749
2024-09-04 17:32:07 +00:00
f5b0caee71 Rewrite unsafe_remove_auto_functionalized_pass using decompose_auto_functionalized (#134831)
`unsafe_remove_auto_functionalized_pass` can be written as using `decompose_auto_functionalized`, this way we do not have to update it each time we do a change to `auto_functionalize` (Ex https://github.com/pytorch/pytorch/pull/134409) , and we avoid duplicate logics implemented in two different ways.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134831
Approved by: https://github.com/zou3519
2024-08-30 16:27:53 +00:00
a1d0b4d568 Add option to skip functional passes in the pattern matcher's replacement graph (#134364)
The pattern matcher runs DCE and remove_noop_ops on the replacement
graph by default. Previously we had a switch for the DCE. This PR
changes that switch to also control if we run remove_noop_ops.

The context was that there is silent incorrectness with
auto_functionalized. We use the Pattern matcher to decompose
auto_functionalized into a mutable op + clones; remove_noop_ops were
deleting the clones.

Future: can try #134363

Test Plan:
- new test. I wasn't able to produce a silently incorrect example so I
  settled for asserting that clones still exist in the post-grad graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134364
Approved by: https://github.com/eellison
ghstack dependencies: #133639
2024-08-24 00:38:55 +00:00
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
2db28a9611 Revert "[BE]: Update Typeguard to TypeIs for better type inference (#133814)"
This reverts commit bce0caba7804b0787684dbf1f4e1c4d9e3acded5.

Reverted https://github.com/pytorch/pytorch/pull/133814 on behalf of https://github.com/ezyang due to root cause of internal failures not addressed ([comment](https://github.com/pytorch/pytorch/pull/133814#issuecomment-2302466444))
2024-08-21 16:13:34 +00:00
bce0caba78 [BE]: Update Typeguard to TypeIs for better type inference (#133814)
Uses TypeIs instead of TypeGuard for better inference. See https://peps.python.org/pep-0742/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133814
Approved by: https://github.com/ezyang
2024-08-20 17:19:57 +00:00
42097f0ec1 Revert "[BE]: Update Typeguard to TypeIs for better type inference (#133814)"
This reverts commit cf60fe53a83bafec0857d5b49c2054de6ba4cddc.

Reverted https://github.com/pytorch/pytorch/pull/133814 on behalf of https://github.com/jeanschmidt due to Broke 12k internal signals/jobs, @ezyang please help get those changes merged. More details check D61488368 ([comment](https://github.com/pytorch/pytorch/pull/133814#issuecomment-2298210309))
2024-08-20 08:02:49 +00:00
cf60fe53a8 [BE]: Update Typeguard to TypeIs for better type inference (#133814)
Uses TypeIs instead of TypeGuard for better inference. See https://peps.python.org/pep-0742/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133814
Approved by: https://github.com/ezyang
2024-08-18 19:10:16 +00:00
1f66487c69 [BE] Reroute all uses of proxy_tensor.maybe_disable_fake_tensor_mode to fake_tensor.unset_fake_temporarily (#132770)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132770
Approved by: https://github.com/bdhirsh
2024-08-08 23:07:23 +00:00
d1f73fd844 Revert "[BE] Reroute all uses of proxy_tensor.maybe_disable_fake_tensor_mode to fake_tensor.unset_fake_temporarily (#132770)"
This reverts commit 902c6f3a191fb2ecb1976895b3e9eaae4b257b89.

Reverted https://github.com/pytorch/pytorch/pull/132770 on behalf of https://github.com/ezyang due to Removed API was recommitted ([comment](https://github.com/pytorch/pytorch/pull/132770#issuecomment-2275749689))
2024-08-08 12:54:34 +00:00
902c6f3a19 [BE] Reroute all uses of proxy_tensor.maybe_disable_fake_tensor_mode to fake_tensor.unset_fake_temporarily (#132770)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132770
Approved by: https://github.com/bdhirsh
ghstack dependencies: #132674, #132675, #132421, #132062, #132767, #132769
2024-08-08 12:03:25 +00:00
4db368a475 make functorch CSE respect mutations as barriers (like fsdp.set_) (#132243)
Fixes https://github.com/pytorch/pytorch/issues/132200

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132243
Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/yf225
2024-08-05 21:28:55 +00:00