Commit Graph

118 Commits

Author SHA1 Message Date
2e0e08588e [BE][PYFMT] migrate PYFMT for torch/[e-n]*/ to ruff format (#144553)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144553
Approved by: https://github.com/ezyang
ghstack dependencies: #144551
2025-06-17 08:18:47 +00:00
f4ac9a160d [fx] Filter stacktrace (#151029)
Filtering out the stacktrace so that the stacktrace on nodes when using fx.Tracer looks nicer. I just copied the filtering we have in [proxy_tensor.py](6720d23969/torch/fx/experimental/proxy_tensor.py (L1903-L1931)).

Previously the stacktrace looked like:
```
File "/data/users/angelayi/pytorch/moo.py", line 3964, in <module>
    run_tests()
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 1342, in run_tests
    unittest.main(argv=argv)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3324, in run
    self._run_custom(
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3296, in _run_custom
    super_run(result=result)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3156, in wrapper
    method(*args, **kwargs)
  File "/data/users/angelayi/pytorch/moo.py", line 1495, in test_stack_trace
    gm = torch.fx.GraphModule(m, tracer.trace(m))
  File "/data/users/angelayi/pytorch/torch/fx/_symbolic_trace.py", line 837, in trace
    (self.create_arg(fn(*args)),),
  File "/data/users/angelayi/pytorch/moo.py", line 1485, in forward
    x = x * 2
  File "/data/users/angelayi/pytorch/torch/fx/proxy.py", line 716, in impl
    return tracer.create_proxy("call_function", target, args, kwargs)
  File "/data/users/angelayi/pytorch/torch/fx/proxy.py", line 248, in create_proxy
    proxy.node.stack_trace = "".join(CapturedTraceback.extract().format())
```
Now it looks like:
```
File "/data/users/angelayi/pytorch/moo.py", line 1485, in forward
    x = x * 2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151029
Approved by: https://github.com/jfix71, https://github.com/zou3519, https://github.com/jingsh
2025-04-22 22:50:36 +00:00
c41fbb4f78 Change arg_kwarg_vals propagation strategy (#148046)
Instead of always propagating arg_kwarg_vals in _COPY_META_FIELDS, we
special-case the pattern matcher to propagate arg_kwarg_vals when
it sees triton_kernel_wrapper_functional.

The strategy is:
1) trace out the replacement graph with arg_kwarg_vals (which have accurate eager-mode metadata)
2) trace out the replacement graph with vals (which have the accurate Inductor metadata)
3) Propagate the arg_kwarg_vals from the first graph to the second.
4) Use the second graph as the replacement graph.

The strategy is this because we want to extend this to handle
auto_functionalized later up in the stack.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148046
Approved by: https://github.com/eellison
2025-04-02 13:17:52 +00:00
a60b4ed623 [fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)
Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds
After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148292
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260, #148261, #148288
2025-03-10 16:06:19 +00:00
bec7bdad47 [fx] Move map_aggregate to C++ (#148243)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
30603618 function calls (29403419 primitive calls) in 13.744 seconds
```
after:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148243
Approved by: https://github.com/oulgen
2025-03-10 16:05:53 +00:00
92beda54c8 Revert "[fx] Move map_aggregate to C++ (#148243)"
This reverts commit edaff88f69f069d517b72ea23fd5eb04702eb0b5.

Reverted https://github.com/pytorch/pytorch/pull/148243 on behalf of https://github.com/jovianjaison due to breaking internal builds [T216910920] ([comment](https://github.com/pytorch/pytorch/pull/148243#issuecomment-2698724058))
2025-03-04 19:40:21 +00:00
ed9055c303 Revert "[fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)"
This reverts commit 8531d247ba411993f9a10686d70514f6945f9960.

Reverted https://github.com/pytorch/pytorch/pull/148292 on behalf of https://github.com/clee2000 due to something in this stack broke some dynamo and higher order ops tests like higher_order_ops/test_invoke_subgraph.py::TestInvokeSubgraphCompile::test_dedupe [GH job link](https://github.com/pytorch/pytorch/actions/runs/13645082540/job/38149882002) [HUD commit link](8531d247ba).   dynamo/test_graph_deduplication did run on the PR but the higher_order_ops one didn't, probably combo of landrace and bad TD ([comment](https://github.com/pytorch/pytorch/pull/148288#issuecomment-2698365172))
2025-03-04 17:10:12 +00:00
8531d247ba [fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)
Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds
After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148292
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260, #148261, #148303, #148288
2025-03-04 02:42:23 +00:00
edaff88f69 [fx] Move map_aggregate to C++ (#148243)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
30603618 function calls (29403419 primitive calls) in 13.744 seconds
```
after:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148243
Approved by: https://github.com/oulgen
2025-03-02 22:42:31 +00:00
c839fa4dd2 [Resubmit] Record input strides at time of tracing, constrain to them for triton fn (#147861)
Resubmit of https://github.com/pytorch/pytorch/pull/145448. it lost its changes on rebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147861
Approved by: https://github.com/zou3519
2025-02-26 05:05:06 +00:00
92d448ff62 Add self to CODEOWNERS for fx/proxy.py; warn against adding new node arg types (#147031)
Not sure if there's a better way

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147031
Approved by: https://github.com/StrongerXi
ghstack dependencies: #147016, #147012, #147013
2025-02-13 18:21:21 +00:00
0d6343347f Revert "Record inputs at time of tracing, constrain to them for triton fn (#145448)"
This reverts commit a699034eeca8c096c44a690e405a60efa442d4ed.

Reverted https://github.com/pytorch/pytorch/pull/145448 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. See D68779678 for details ([comment](https://github.com/pytorch/pytorch/pull/145448#issuecomment-2622470810))
2025-01-29 18:07:12 +00:00
a699034eec Record inputs at time of tracing, constrain to them for triton fn (#145448)
Record input fake tensors at time of tracing and store them in the node meta. Inductor passes have the possibility of changing strides, so it is safer to record the strides of the inputs at tracing. See, https://github.com/pytorch/pytorch/issues/137979 for more context.

We can also extend this to custom ops, and user-visible outputs. If this ends up being compilation time sensitive we can just record strides (and maybe storage offset, per @zou3519) instead of the complete fake tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145448
Approved by: https://github.com/zou3519
2025-01-28 07:07:14 +00:00
0b2a3687b9 PEP585 update - torch/fx (#145166)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
2025-01-20 18:11:54 +00:00
9f48881ba8 [BE]: Enable RUF013 ban implicit optional (#141706)
Enables RUF013 rule to ban implicit Optional (from areas not already checked by mypy).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141706
Approved by: https://github.com/ezyang
2024-11-28 04:03:01 +00:00
abbd71d29d [BE][Easy] enable PYFMT for torch.fx (#138443)
Reproduce command:

```bash
ghstack checkout https://github.com/pytorch/pytorch/pull/138443
git checkout HEAD~1 torch/
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138443
Approved by: https://github.com/ezyang
2024-10-21 19:15:49 +00:00
9b01d17b8d Use MetaProxy more pervasively (#137588)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137588
Approved by: https://github.com/ezyang
ghstack dependencies: #136674
2024-10-09 23:22:03 +00:00
36133f39db Tensorify compute on Python scalars (#136674)
Signed-off-by: Bob Ren <bobrenfb.com>

Comandeered from https://github.com/pytorch/pytorch/pull/130228 as I'm helping @ezyang w/ shipping dynamic float arguments in PT2. This starts with supporting torch.ops.aten.mul. I'll stack on top support for other operators in subsequent PRs to keep this scoped to the mechanics of the fx pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136674
Approved by: https://github.com/ezyang
2024-10-09 18:51:41 +00:00
a815611db9 [Traceable FSDP2][Partitioner] Must save AC output if output has a backward hook (#135727)
If node is AC region output and has a backward hook on it, we intentionally choose to save it.
This is to work around circular dependencies in Traceable FSDP2+AC.
Example:
```
out = fully_shard(utils.checkpoint(module))(x)
norm_out = layer_norm(out)
```
and there is a circular dependency:
1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed.
3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad`  -> circular dependency with (1)!

Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency.

----

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727
Approved by: https://github.com/Chillee
2024-09-14 08:45:58 +00:00
a72124add9 [fx] Minor optimization in create_arg (#135821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135821
Approved by: https://github.com/oulgen
ghstack dependencies: #135787, #135788, #135820
2024-09-13 00:18:41 +00:00
240467adfe [fx] Implement deepcopy for Proxy (#133706)
Summary: When deepcopy a proxy, we first try the default deepcopy behavior.

Test Plan: buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:fx -- -r  proxy_deepcopy

Differential Revision: D61398418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133706
Approved by: https://github.com/angelayi
2024-08-22 16:37:30 +00:00
f23dbefe52 [export] Support "custom" metadata field. (#131912)
Summary:
Add a special field in Graph and Node level metadata called "custom" which should be mapped to a json-serializable object, and we guarantee this field should be always preversed across the following transformations:
1. copy/deepcopy
2. run_decompositions()
3. serialization
4. re-exporting

Test Plan: :test_export -- -r custom_tag

Reviewed By: angelayi

Differential Revision: D60291839

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131912
Approved by: https://github.com/angelayi
2024-08-14 01:09:01 +00:00
3c5b246d3c [export] Remove Proxy from exported programs and modules (#132956)
Summary: Remove Proxy from exported programs and modules because they cannot be deepcopied or pickeled.

Test Plan:
CI

```
buck2 run 'fbcode//mode/dev-nosan'  fbcode//caffe2/test/quantization:test_quantization -- -r  qat_conv2d
buck2 run 'fbcode//mode/dev-nosan' fbcode//modai/test:test_modai -- -r test_qat_stinson_htp_export
buck2 run 'fbcode//mode/dev-nosan' fbcode//vizard_projects/ml_depth/tests:test_model -- -r test_qat_model_et
buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=False,use_3d_input=False
buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=True,use_3d_input=False
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r  test_fold_bn_erases_bn_node
```

Differential Revision: D60940832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132956
Approved by: https://github.com/angelayi
2024-08-09 00:00:20 +00:00
aec6332356 Only thunkify proxies in some situations (#132421)
The goal of this PR is to avoid stack overflow when we create extremely long chains of thunks, and then evaluate them (e.g., as occurs if you sum(long list of symint)). The basic idea behind this PR is to only thunkify proxies if they're being created in places where they may or may not be used--crucially, symint operations that occur in user code we are tracing are eagerly placed into the graph, even if they may eventually be dead.

I annotated the PR with explanation of changes.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132421
Approved by: https://github.com/Skylion007, https://github.com/zou3519
ghstack dependencies: #132674, #132675
2024-08-08 12:03:06 +00:00
825002c9c6 [export][fx] More robust DCE pass (#132764)
Summary:
- make default DCE pass check schema,
- need to rebase onto https://github.com/pytorch/pytorch/pull/131651 after it's in phabricator (for now the change is manually added).

- mark Proxy dump as NotImplemented for better error msg

- Remove Proxy from tensors when dumping models, as Proxy cannot be dumped.

More details in https://docs.google.com/document/d/1G5vmTXjzxoyVGRI2kpA1gQukK_Glyg2NrE0Oh6Nlg9A/edit?usp=sharing.

Test Plan:
CI
```
- buck2 run 'fbcode//mode/dev-nosan'  fbcode//caffe2/test/quantization:test_quantization -- -r  qat_conv2d
- test_export.py
- buck2 run 'fbcode//mode/dev-nosan' fbcode//modai/test:test_modai -- -r test_qat_stinson_htp_export
- buck2 run 'fbcode//mode/dev-nosan' fbcode//vizard_projects/ml_depth/tests:test_model -- -r test_qat_model_et
- buck2 run 'fbcode//mode/dev-nosan'  fbcode//caffe2/test:fx -- -r dce
- buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=False,use_3d_input=False
- buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=True,use_3d_input=False
- buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r  test_fold_bn_erases_bn_node
```

Reviewed By: angelayi

Differential Revision: D60319175

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132764
Approved by: https://github.com/angelayi
2024-08-06 22:27:22 +00:00
72d2dba992 Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
58b8704f28 [aot] Keep backward mutations in backward (#129130)
https://github.com/pytorch/pytorch/issues/127561

Mutations of inputs in backward are emitted manually, after joint_fn tracing.
With default partitioner logic they will be moved to "forward" graph, as this is operation on forward inputs.

To keep those mutations in backward:
- Introduce "subgraph" node key, that can be specified with contextmanager. When we do manual `copy_` in backward on forward input - we know that his is for backward - set subgraph="backward"

In partitioner:
Introducing optional argument subgraph, to filter out nodes with specified subgraph (node_subgraph) and not to add them to subgraph if node_subgraph is different.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129130
Approved by: https://github.com/Chillee
2024-07-25 20:02:25 +00:00
df9d1b44e7 Preserve _numeric_debug_handle throguh deepcopy and re-export (#129287)
Summary:
* Added support for preserving it during deepcopy, need to remap the args since _numeric_debug_handle refers
to the nodes in the graph

TODO: need to fully support re-export, currently the metadata for output node is not preserved

Test Plan:
python test/test_quantization.py -k test_deepcopy_preserve_handle
python test/test_quantization.py -k test_copy_preserve_handle

all related tests:
python test/test_quantization.py -k TestGenerateNumericDebugHandle

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129287
Approved by: https://github.com/zhxchen17
2024-07-11 02:19:41 +00:00
575bc1e3af [Reopen #114036] Allow "must recompute" in torch.compile + selective checkpointing (SAC) (#129295)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129295
Approved by: https://github.com/Chillee
2024-06-25 23:47:08 +00:00
e9c6e8369c Torchbind call method + effects support (#128397)
Adds effect token support to torchbind method calls by allowing `with_effects` to take in `torch.ops._higher_order_ops.call_torchbind` as an input.

Here is the print from `TORCH_LOGS="aot" python test/export/test_torchbind.py -k test_compile_obj_torchbind_op`:
```python
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2]", arg2_1):
    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1266 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
    cos: "f32[2]" = torch.ops.aten.cos.default(arg1_1)
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, cos);  arg0_1 = cos = None
    getitem: "f32[0]" = with_effects[0];  with_effects = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1267 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
    cos_1: "f32[2]" = torch.ops.aten.cos.default(arg1_1)
    add: "f32[2]" = torch.ops.aten.add.Tensor(cos_1, 1);  cos_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, add);  getitem = add = None
    getitem_2: "f32[0]" = with_effects_1[0];  with_effects_1 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1268 in f, code: torch.ops._TorchScriptTesting.queue_pop(tq)
    with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg2_1);  getitem_2 = None
    getitem_4: "f32[0]" = with_effects_2[0];  with_effects_2 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1269 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
    sin: "f32[2]" = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, sin);  getitem_4 = sin = None
    getitem_6: "f32[0]" = with_effects_3[0];  with_effects_3 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1270 in f, code: return tq.pop(), tq.pop() + tq.size(), tq
    with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop');  getitem_6 = None
    getitem_8: "f32[0]" = with_effects_4[0]
    getitem_9: "f32[2]" = with_effects_4[1];  with_effects_4 = None
    with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop');  getitem_8 = None
    getitem_10: "f32[0]" = with_effects_5[0]
    getitem_11: "f32[2]" = with_effects_5[1];  with_effects_5 = None
    with_effects_6 = torch._higher_order_ops.effects.with_effects(getitem_10, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'size');  getitem_10 = arg2_1 = None
    getitem_12: "f32[0]" = with_effects_6[0];  with_effects_6 = None
    add_1: "f32[2]" = torch.ops.aten.add.Tensor(getitem_11, 0);  getitem_11 = None
    return (getitem_12, getitem_9, add_1)
```

In order to support this, this PR makes the following changes:
* Adds `FakeScriptObject` to `CustomObjArgument`, which will be put on the `meta["val"]` of nodes representing torchbind objects.
* Adds pickle/deepcopy support to FunctionSchema.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128397
Approved by: https://github.com/ydwu4, https://github.com/zou3519
2024-06-14 21:28:17 +00:00
ed457c7dbe [export] Add torch_fn (#122693)
This PR adds a new metadata, `torch_fn` which is meant to replace `source_fn_stack` as `source_fn_stack` is not entirely well defined between strict/nonstrict. Previous discussion [here](https://docs.google.com/document/d/1sPmmsmh6rZFWH03QBOe49MaXrQkP8SxoG8AOMb-pFk4/edit#heading=h.anmx9qknhvm).

`torch_fn` represents the torch function that a particular aten operator came from. For example, `torch.nn.Linear` goes down to the `torch.nn.functional.linear` at the `__torch_function__` layer, and then `aten.t/aten.addmm` in the `__torch_dispatch__` layer. So the nodes `aten.t/aten.addmm` will now have the `torch_fn` metadata containing the `torch.nn.functional.linear`.

The `torch_fn` metadata is a tuple of 2 strings: a unique identifier for each torch function call, and the actual torch function `f"{fn.__class__}.{fn.__name__}"`. The purpose of the first value is to distinguish between 2 consecutive calls to the same function. For example, if we had 2 calls to `torch.nn.Linear`, the nodes and corresponding metadata would look something like:
```
aten.t - ("linear_1", "builtin_function_or_method.linear"),
aten.addmm - ("linear_1", "builtin_function_or_method.linear"),
aten.t - ("linear_2", "builtin_function_or_method.linear"),
aten.addmm - ("linear_2", "builtin_function_or_method.linear"),
```

Higher order ops -- currently we can get the torch_fn metadata for nodes within the HOO's subgraph, but after retracing, this becomes the `(cond, higher_order_op.cond)` :( This is because `fx_traceback.set_current_meta` points to the cond node in the toplevel graph, rather than the original node in the subgraph. I think this is because `fx.Interpreter` does not go into the cond subgraphs. (will discuss with Yidi more ab this)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122693
Approved by: https://github.com/tugsbayasgalan
2024-03-30 06:47:15 +00:00
1d96791661 [dynamo] Fix list proxy to list element proxy source propagation (#122691)
Currently, when we create proxies for a list's elements in wrap_fx_proxy_cls, we create them using the same source as the list's e.g. `LocalSource(inputs)` instead of `GetItemSource(LocalSource(inputs), index=i)`. This results in invalid guards when the tensors it contains becomes dynamic, and the guard system thinks the list is a tensor:
```
Malformed guard:
L['sizes'][0] == L['inputs'].size()[0]
Malformed guard:
2 <= L['inputs'].size()[0]

Traceback [...]
AttributeError: 'list' object has no attribute 'size'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122691
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-03-28 14:40:54 +00:00
7c5e29ae71 Back out "Support triton.language.dtype with torch.compile (#121690)" (#122108)
Summary: Some hard to deal with package import/export related problems. Lets revert and start with clean slate.

Test Plan: CI

Differential Revision: D55024877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122108
Approved by: https://github.com/ezyang
2024-03-18 20:50:28 +00:00
65ccac6f17 Fix triton import time cycles (#122059)
Summary: `has_triton` causes some import time cycles. Lets use `has_triton_package` which is enough.

Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//fblearner/flow/projects/model_processing/pytorch_model_export_utils/logical_transformations/tests:filter_inference_feature_metadata_test -- --exact 'fblearner/flow/projects/model_processing/pytorch_model_export_utils/logical_transformations/tests:filter_inference_feature_metadata_test - test_collect_features_from_graph_module_nodes (fblearner.flow.projects.model_processing.pytorch_model_export_utils.logical_transformations.tests.filter_inference_feature_metadata_test.FilterInferenceFromFeatureMetadataTest)'
```
now passes

Differential Revision: D55001430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122059
Approved by: https://github.com/aakhundov
2024-03-18 05:50:32 +00:00
02bb2180f4 [torch export] replace traceback.extract_stack with CapturedTraceback.extract (#121449)
Summary:
with a simple bench in TestDeserializer.test_basic function:
```
time_start = time.time()
for i in range(1000):
    self.check_graph(MyModule(), inputs)
warnings.warn(f"time_taken: {time.time() - time_start}")
```
and forcing FakeTensorConfig.debug to True, record_stack_traces to True, logging level to debug, it shows that the the changed code is consistently ard 20 secs faster (~90s vs originally ~110s)

Test Plan:
test passed, see summary

compared debug trace before and after:
- exactly the same for fake tensor and proxy callsite https://www.internalfb.com/intern/diffing/?paste_number=1189883685
- slightly different for the user frame in proxy node https://www.internalfb.com/intern/diffing/?paste_number=1189884347

Differential Revision: D54237017

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121449
Approved by: https://github.com/angelayi
2024-03-13 00:19:05 +00:00
79ee6bbde3 Support triton.language.dtype with torch.compile (#121690)
Putting this PR as an RFC since I have resorted to some horrible hacks in order to make this work.
```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```
This means that we need to "rewrite" them for fx graph and inductor execution.

This PR allows Mamba2 to work with `torch.compile`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
2024-03-12 23:21:46 +00:00
9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00
b8bd3bb30a Fix aot_autograd seq_nr logic (#118249)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118249
Approved by: https://github.com/zou3519
ghstack dependencies: #117552, #118234
2024-01-25 22:56:20 +00:00
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
a0be4b7ea7 [fx] Update symbolic_trace nn_module_stack (#114422)
Summary:
Fixed nn_module_stack dynamo produced by symbolic trace to align with the nn_module_stack metadata produced by dynamo. The key should be the module path, with the value being a unique name, and the type. Something like: `{'L__self___one_module': ("L['self'].one_module", <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>)}`

This was causing some tests to fail when using export + the old quantization flow (prepare_fx calls symbolic_trace).

Test Plan: D51534471 `buck2 run @//mode/dev-nosan //executorch/backends/xnnpack/test:test_xnnpack_quantized -- -r "test_xnnpack_leaky_relu"`

Differential Revision: D51539118

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114422
Approved by: https://github.com/JacobSzwejbka, https://github.com/jerryzh168
2023-11-28 00:18:41 +00:00
473b17c4c1 Run sympy expressions with Python values / FX tracing (#113978)
To codegen deferred runtime asserts, I need to be able to convert sympy expressions back into regular Python expressions that I can put in FX graphs. This PR adds some of the machinery to do this: it adds a new sympy analysis that runs operations on all FX traceable operations that can also be run with plain Python int/float/bool/etc. It's tested by symbolic tracing through the analysis, and then testing that this traced graph gives the same result as running the Python analysis directly.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113978
Approved by: https://github.com/aakhundov, https://github.com/lezcano
2023-11-20 21:25:11 +00:00
9e2af971fc [Quantization] Add "quantization_tag" as metadata to fx proxy (#108764)
Summary:
In order to make sure that quantization_tag is preserved through second
stage export, this PR adds it as a special metadata that should be
preserved.

Since quantization in export path will work on top of pre dispatch
graph, subsequent post dispatch op decomposition, will decompose ops
that quant workflow tagged. In order to make sure that the patterns
identified by quantizer, remains identifiable, even after decompositions
are applied, we must preserve "quantization_tag".

This enables backend delegates, that quantized a model for specific
backend, to be able to identify "quantized" patterns.

Test Plan:
metadata porting tests

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D49056259](https://our.internmc.facebook.com/intern/diff/D49056259)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108764
Approved by: https://github.com/tugsbayasgalan, https://github.com/jerryzh168
2023-11-01 21:41:58 +00:00
00d962631c [BE] Enable Ruff's Flake8 PYI045 (#111184)
Enable [iter-method-return-iterable (PYI045)](https://docs.astral.sh/ruff/rules/iter-method-return-iterable/#iter-method-return-iterable-pyi045)

Link: #110950
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111184
Approved by: https://github.com/Skylion007
2023-10-13 22:20:04 +00:00
5f7eff0adb Replace node.meta source_fn with source_fn_stack (#108595)
A resubmit of https://github.com/pytorch/pytorch/pull/108447. Copy over the descriptions:

This is a follow-up of the discussion in https://github.com/pytorch/pytorch/pull/108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```

Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108595
Approved by: https://github.com/angelayi, https://github.com/zou3519
2023-09-28 18:18:36 +00:00
6e3a7473cf Trace calls with Python Enum values. (#109507)
Fix: #82135
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109507
Approved by: https://github.com/ezyang
2023-09-20 22:18:11 +00:00
3cc5c42a23 Fix aot sequence_nr to reset bwd flag (#107210)
The way the aot autograd sequence_nr tracking works is that we run the aot export logic, the dynamo captured forward graph is run under an fx.Interpreter, which iterates through the nodes of the forward graph while setting the `current_metadata`.
Since during backward what is run doesn't correspond to any node during forward, we fallback to the global `current_metadata`. And since this global metadata is ends up being shared between runs, that leads to weirdness if we forget to reset things, e.g., depending whether this is the first test run, the printed results will be different.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107210
Approved by: https://github.com/bdhirsh
2023-08-24 16:58:12 +00:00
4c46ea583f [Export] Support re-exportability (#106531)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106531
Approved by: https://github.com/zhxchen17
2023-08-03 18:27:26 +00:00
9ba0558d48 Add sequence_nr to aot_autograd to map forward ops to their corresponding backward ops (#103129)
Fixes #102375

Sequence_nr increments in the forward pass and decrements in the backward pass.  Backward ops with the same sequence_nr as a forward op represent the backward implementation for the op.  The long term goal is to make this information available to the profiler so users can observe which ops are fused by the inductor openai triton kernels.

Added a test for this feature **test/dynamo/test_aot_autograd.py::AotAutogradFallbackTests::test_aot_sequence_nr**.  The test case uses **aot_export_module()** to create a joint fwd/bwd fx graph.  Then it walks all the nodes in fx graph using fx_graph.graph.nodes.   The seq_nr of each node is recorded in node.meta.  During the fwd pass the seq_nr increments and it decrements during the bwd pass.  This allows the user to map forward ops to their corresponding bwd ops which is useful for performance analysis.

Expected output from the test case.

 SeqNr|OrigAten|SrcFn
0|aten.convolution.default|l__self___conv1
0|aten.add.Tensor|l__self___bn1
1|aten._native_batch_norm_legit_functional.default|l__self___bn1
2|aten.relu.default|l__self___relu1
3|aten.add.Tensor|add
4|aten.view.default|flatten
5|aten.t.default|l__self___fc1
6|aten.unsqueeze.default|l__self___fc1
7|aten.mm.default|l__self___fc1
8|aten.squeeze.dim|l__self___fc1
9|aten.add.Tensor|l__self___fc1
10|aten.sub.Tensor|l__self___loss_fn
11|aten.abs.default|l__self___loss_fn
12|aten.mean.default|l__self___loss_fn
12|aten.ones_like.default|
12|aten.expand.default|
12|aten.div.Scalar|
11|aten.sgn.default|
11|aten.mul.Tensor|
8|aten.unsqueeze.default|
7|aten.t.default|
7|aten.mm.default|
7|aten.t.default|
7|aten.t.default|
7|aten.mm.default|
6|aten.squeeze.dim|
5|aten.t.default|
4|aten.view.default|
2|aten.threshold_backward.default|
1|aten.native_batch_norm_backward.default|
0|aten.convolution_backward.default|
0|aten.add.Tensor|

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103129
Approved by: https://github.com/soulitzer
2023-08-02 00:52:52 +00:00
a44f8894fa [Inductor] Provenance tracking for wrapper code (#105717)
Summary:
Add comments in wrapper code for better provenance tracking

Sample inductor wrapper output:
```
# Source Nodes: [mm_1], Original ATen: [aten.mm]
extern_kernels.mm(as_strided(tangents_1, (500, 20), (1, 500)), view, out=buf1)

# Source Nodes: [l__self___linear], Original ATen: [aten.addmm]
extern_kernels.addmm(primals_2, as_strided(primals_3, (20, 500), (500, 1)), as_strided(primals_1, (500, 500), (1, 500)), alpha=1, beta=1, out=buf0)
```

in cpp wrapper
```
        // Source Nodes: [bmm_1], Original ATen: bmm
        at::bmm_out(buf0, arg0_1, arg1_1);
```

Test Plan: OSS CI

Differential Revision: D47657260

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105717
Approved by: https://github.com/desertfire, https://github.com/jansel
2023-07-21 23:06:43 +00:00
133c5ec997 Add torch.ops.out_dtype (#103333)
https://docs.google.com/document/d/10DYFG2sU3TSvguFP5kYwYLlo45KHFg3BhBOkUk0NKsU/edit#bookmark=id.hgfzmhlzkamk

Renamed mixed_dtype --> out_dtype because "mixed_dtype is not very descriptive in the context of regular pytorch where we support type promotion on most ops"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103333
Approved by: https://github.com/zou3519
2023-07-18 16:25:45 +00:00