Commit Graph

47 Commits

Author SHA1 Message Date
b11593c31b [8/N] Apply ruff UP035 rule (#165214)
This is a follow-up of #164653, continuing to apply `UP035` fixes. The purpose is to finally enable this rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165214
Approved by: https://github.com/ezyang
2025-10-15 03:18:57 +00:00
7da02bf8af Skip const folding with symbolic expression (#161437)
Summary: When performing constant folding, we must skip over operators whose `fill_value` is a symbolic expression.
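
A minimal sketch of the guard this implies (hypothetical helper name; the actual check lives inside the const-folding pass):

```python
import torch
import torch.fx

def has_symbolic_fill_value(node: torch.fx.Node) -> bool:
    # Hypothetical guard: a SymInt/SymFloat/SymBool fill_value is a symbolic
    # expression that cannot be evaluated at fold time, so the node must be
    # excluded from constant folding.
    fill_value = node.kwargs.get("fill_value")
    if fill_value is None and len(node.args) > 1:
        fill_value = node.args[1]  # aten.full(size, fill_value) convention
    return isinstance(fill_value, (torch.SymInt, torch.SymFloat, torch.SymBool))
```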

Test Plan:
CI

Reviewed By: kalpit-meta-1

Differential Revision: D80965936

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161437
Approved by: https://github.com/StellarrZ
2025-08-27 22:09:58 +00:00
2d98a1caf5 [MTIA] Map names to operand indices when folding submodules (#150692)
When replacing placeholders with getattrs during constant folding, we can hit an argument/parameter name mismatch: there is no guarantee that a parameter's name matches the argument name used at the module call site.
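
A small self-contained illustration of the mismatch (toy module names): the call site binds operands by position, so folding code should map callee parameter names to operand indices rather than assume the names agree.

```python
import torch
import torch.fx as fx

class Sub(torch.nn.Module):
    def forward(self, x, y):                  # parameter names: x, y
        return x + y

class Top(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, a, b):                  # argument names at the call site: a, b
        return self.sub(a, b)

class KeepSubAsLeaf(fx.Tracer):
    # Keep Sub as a call_module node instead of inlining it while tracing.
    def is_leaf_module(self, m, qualname):
        return isinstance(m, Sub) or super().is_leaf_module(m, qualname)

graph = KeepSubAsLeaf().trace(Top())
call = next(n for n in graph.nodes if n.op == "call_module")
param_names = [n.name for n in fx.symbolic_trace(Sub()).graph.nodes
               if n.op == "placeholder"]
# Map each callee parameter name to its operand index at the call site;
# relying on name equality ('x' == 'a') would fail here.
name_to_operand = {name: call.args[i] for i, name in enumerate(param_names)}
print(name_to_operand)   # {'x': a, 'y': b}
```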

Differential Revision: D72415970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150692
Approved by: https://github.com/jfix71
2025-04-06 03:11:14 +00:00
0b2a3687b9 PEP585 update - torch/fx (#145166)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
2025-01-20 18:11:54 +00:00
abbd71d29d [BE][Easy] enable PYFMT for torch.fx (#138443)
Reproduce command:

```bash
ghstack checkout https://github.com/pytorch/pytorch/pull/138443
git checkout HEAD~1 torch/
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138443
Approved by: https://github.com/ezyang
2024-10-21 19:15:49 +00:00
13b0baf2a1 [FX] Update _inline_module util function to work with both args and kwargs (#136631)
Summary: Previously the `_inline_module` helper function only worked with submodules that had args specified. This diff updates the util function to look for input arguments in the submodule's kwargs first, using placeholder node names, then fall back to the list of args if the node name is not found.
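
A minimal sketch of that lookup order (hypothetical helper; the real logic lives inside `_inline_module`):

```python
import torch.fx as fx

def resolve_inputs(call_node: fx.Node, submod: fx.GraphModule):
    # For each placeholder in the submodule, prefer a kwarg keyed by the
    # placeholder's name; fall back to the positional arg at the same index.
    placeholders = [n for n in submod.graph.nodes if n.op == "placeholder"]
    inputs = []
    for i, ph in enumerate(placeholders):
        if ph.name in call_node.kwargs:
            inputs.append(call_node.kwargs[ph.name])
        elif i < len(call_node.args):
            inputs.append(call_node.args[i])
        else:
            raise RuntimeError(f"no input found for placeholder {ph.name!r}")
    return inputs
```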

Test Plan:
```
buck2 run @//mode/{opt,mtia,inplace} //glow/fb/fx/fba/tests:test_fba_inductor -- -r test_connected_fusions
```

Differential Revision: D63347675

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136631
Approved by: https://github.com/jfix71
2024-09-25 20:20:57 +00:00
5ec9c0bc4a Fix linearize(grad(...)) call (#133364)
Fixes #124550

Also moves the `graph.eliminate_dead_code()` call to a few lines after
`_inline_module(...)` in `const_fold.py`.

Test Plan:

Add a new test in `test_eager_transforms.py` to ensure the reported
issue was indeed fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133364
Approved by: https://github.com/zou3519
2024-08-15 17:55:36 +00:00
d2f44eabe7 [Export] Support aten.full.default and aten.full_like.default (#130639)
Summary: Add operator tests for full & full_like operators

Test Plan:
Rerun kernel test using
```
buck2 run //glow/fba/tests:run_kernel mode/dev -- --kernel splat --config "input=1;dtype=fp32;fill_value=42.0"  -tl_time
```

Operator tests
```
buck2 run mode/{opt,inplace} //caffe2/torch/fb/test_library:afg_operator_test -- -k __full__
```

Differential Revision: D59593849

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130639
Approved by: https://github.com/StellarrZ
2024-07-16 16:50:04 +00:00
7c2058338a Improve convert fp32 to fp16 fx pass (#127829)
Summary: Improve the fp32-to-fp16 conversion FX pass to use a `to_dtype` node plus const folding, instead of in-place conversion.
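
A rough sketch of the approach on a toy module (the pass itself is internal; this only illustrates the graph surgery): insert an explicit cast node after each weight `get_attr` instead of mutating the weight in place.

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        return x + self.w

gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "get_attr":
        with gm.graph.inserting_after(node):
            cast = gm.graph.call_method("to", (node,), {"dtype": torch.float16})
        # Rewire all users of the weight through the cast, but keep the cast
        # itself consuming the original attribute.
        node.replace_all_uses_with(cast, delete_user_cb=lambda u: u is not cast)
gm.recompile()
print(gm.code)   # forward now routes the weight through .to(dtype=torch.float16)
```

Since the cast's only input is a constant, a const-folding pass can then replace it with a precomputed fp16 attribute.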

Test Plan:
```
buck2 test @//mode/{opt,inplace} //glow/fb/fx/fba/tests:test_fba_pass_manager_builder
```

Differential Revision: D57803843

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127829
Approved by: https://github.com/Skylion007
2024-06-12 02:50:37 +00:00
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
d156cb2e12 Fix mem size mismatch from split/chunk in const folding (#125199)
Summary:
The chunk/split ops on weights/constants are folded in an FX pass, and each output tensor keeps the same storage size as the original tensor (3x its actual size in the case of chunk(3)). The backend, however, calculates the on-device memory size from the tensor's shape/stride/dtype. This causes a mismatch when copying weights/constants to the device: the memory allocated on the device is always smaller than the weights/constants, which results in a runtime error when loading them (T172125529).

This diff fixes the issue by cloning the tensors after const folding so that they have the correct storage size.
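
A self-contained demonstration of the mismatch (illustrative sizes; `untyped_storage` assumes a recent PyTorch):

```python
import torch

w = torch.randn(48, 64, dtype=torch.float16)   # stand-in for a folded constant
a, b, c = torch.chunk(w, 3)                    # chunks are views into w's storage

print(w.untyped_storage().nbytes())            # 6144 = 48 * 64 * 2 bytes
print(a.untyped_storage().nbytes())            # 6144 as well: `a` shares w's storage
print(a.clone().untyped_storage().nbytes())    # 2048: the clone owns a right-sized buffer
```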

Test Plan:
Before this change: (18432 = 48 * 64 * 2 * 3)
```
RuntimeError: Failed to load constant getitem_idx0 split (remaining=18432) at fbcode/caffe2/torch/fb/acc_runtime/afg/afg_bindings.cpp:3422: Request failed because an invalid parameter
```

```
buck2 run mode/opt //caffe2/torch/fb/acc_runtime/afg/tests:test_operators-artemis -- -r test_mem_size_mismatch
```
```
Ran 1 test in 7.048s

OK
```

Reviewed By: jfix71

Differential Revision: D56663931

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125199
Approved by: https://github.com/jfix71
2024-05-03 04:42:38 +00:00
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45ef53747e2eefffd65d91ce840b431b.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to squash the older libstdc++ from the conda environment in favor of the one from the OS in `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3dd4e66059522bf5f5c1ba0431e2069.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262f82bbc76aa776119c9fea079fbffe3.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
a475ea4542 [fx] change from #users to num_users in graph printout (#101140)
`#users` reads as a tag or mention in various chat apps, which makes it annoying to copy-paste graphs into them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101140
Approved by: https://github.com/ezyang
2023-06-20 21:24:32 +00:00
66eef31444 Revert "[fx] change from #users to num_users in graph printout (#101140)"
This reverts commit e568c5a18d0fb390437912e13c96e29697af27ec.

Reverted https://github.com/pytorch/pytorch/pull/101140 on behalf of https://github.com/jeanschmidt due to There are internal changes to this commit that are preventing landing, so I am reverting to unblock the diff train ([comment](https://github.com/pytorch/pytorch/pull/101140#issuecomment-1547989487))
2023-05-15 14:35:22 +00:00
e568c5a18d [fx] change from #users to num_users in graph printout (#101140)
`#users` reads as a tag or mention in various chat apps, which makes it annoying to copy-paste graphs into them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101140
Approved by: https://github.com/ezyang
2023-05-12 04:34:01 +00:00
4f3858c6d8 [functorch] linearize (#94173)
Fixes https://github.com/pytorch/functorch/issues/724

TODO:
* [x] Docs

NOTE: `const_fold` pass raises UserWarning -> https://github.com/pytorch/pytorch/issues/94374

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94173
Approved by: https://github.com/Chillee
2023-02-09 15:45:08 +00:00
e0e4f1a890 Revert "[functorch] linearize (#94173)"
This reverts commit b6b9e1e6e043ae4b9f41fbbee4f2a9e9a7e7d3d7.

Reverted https://github.com/pytorch/pytorch/pull/94173 on behalf of https://github.com/kshitij12345 due to Broke lint runner
2023-02-09 09:22:39 +00:00
b6b9e1e6e0 [functorch] linearize (#94173)
Fixes https://github.com/pytorch/functorch/issues/724

TODO:
* [x] Docs

NOTE: `const_fold` pass raises UserWarning -> https://github.com/pytorch/pytorch/issues/94374

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94173
Approved by: https://github.com/Chillee
2023-02-09 08:57:05 +00:00
4dc579838b Allow fx.Graph.owning_module to be used as attribute. (#86822)
Summary:
The current behavior of the owning_module setter is difficult to understand: it sets owning_module to None if the owners count is not 0, yet still increments that count. If owning_module is None, the owners count should be 0, since no owner is accessible; conversely, if the owners count can grow, owning_module should be a collection (e.g. a list).

This diff changes owning_module to be a normal attribute. The semantics are that a graph can have **at most one** owning module and can be assigned to a new module.

The alternative is to use a list to represent the owning_modules of a graph but it breaks backward compatibility and the exact use cases of having multiple owning_modules are not clear.
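
A small sketch of the resulting semantics (assuming current `torch.fx` behavior):

```python
import torch
import torch.fx as fx

gm = fx.symbolic_trace(torch.nn.Linear(2, 2))
graph = gm.graph
assert graph.owning_module is gm      # at most one owning module

other = fx.symbolic_trace(torch.nn.Linear(2, 2))
graph.owning_module = other           # a plain attribute: reassignment is allowed
assert graph.owning_module is other
```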

Test Plan: Test with CI.

Differential Revision: D40200624

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86822
Approved by: https://github.com/tugsbayasgalan
2022-10-19 00:12:59 +00:00
c00f0c80c0 [fx] add deferred weights (xl_weight) and tracing for xl_embedding_bag (#84016)
Test Plan: added unit tests

Reviewed By: jfix71

Differential Revision: D36152238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84016
Approved by: https://github.com/jfix71
2022-08-25 07:10:30 +00:00
d4bbebbb97 Back out "Back out "[const_fold] Set requires_grad based on the folded tensor; add device_for_folding option"" (#79696)
Summary:
This is an un-backout but with a small change to set the default device `device_for_folded_attrs="cuda"` instead of `"cpu"`, which should avoid BC issues for TRT lowering.

Original commit changeset: 4ae1863e28ff

Original Phabricator Diff: D37192230 (24c2aff1b2)

Test Plan: Added unit test

Differential Revision: D37205432

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79696
Approved by: https://github.com/dborkovic
2022-06-18 19:13:49 +00:00
e627dd542d Revert "Back out "[const_fold] Set requires_grad based on the folded tensor; add device_for_folding option" (#79655)"
This reverts commit 24c2aff1b23729b5ded089ffbc81ac3239fc47bd.

Reverted https://github.com/pytorch/pytorch/pull/79655 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-06-17 22:09:48 +00:00
24c2aff1b2 Back out "[const_fold] Set requires_grad based on the folded tensor; add device_for_folding option" (#79655)
Reviewed By: yinghai, khabinov

Differential Revision: D37192230

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79655
Approved by: https://github.com/khabinov, https://github.com/yinghai
2022-06-16 06:51:53 +00:00
f614f66acf [const_fold] Set requires_grad based on the folded tensor; add device_for_folding option (#79067)
Summary: As titled.

Test Plan: Added unit test coverage for tensor_meta part.

Reviewed By: wushirong

Differential Revision: D36975932

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79067
Approved by: https://github.com/dborkovic
2022-06-14 19:00:05 +00:00
cf66d42012 Add mapping for unsqueeze_n_times (#75043)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75043

Reviewed By: RoshanPAN, garroud

Differential Revision: D35293706

fbshipit-source-id: 2fe1f6ee54376ef9796ae66cb4d31533ea92bb42
(cherry picked from commit bf3c02f760ffda0d648d4e95ef2e0e381af4cdaa)
2022-04-01 05:02:56 +00:00
32dd4a8639 move fx_acc out of pytorch core (#72803)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72803

As titled.

Reviewed By: jfix71

Differential Revision: D34101788

fbshipit-source-id: a9fd84671929af21405c049603e9895ec68de3d8
(cherry picked from commit e98fd1c32d34d26a3a4e8ad8ace041735312422e)
2022-02-15 16:13:43 +00:00
68d8ab0cc6 [const_fold] Fix call_module const folding (#68614)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68614

We need to copy modules over to the `split` graph during const folding. We were previously only doing so from the non-constant submod, but we need to do this for the constant one as well in case some `call_module` is const folded.

Test Plan: Added unit test

Reviewed By: wushirong, 842974287

Differential Revision: D32543289

fbshipit-source-id: 80d1d0ce2c18a665b00e1343d6c55d939390ab10
2021-11-18 20:56:06 -08:00
1c0d6ff835 [fx][const fold] Allow to set up a function to modify const_nodes for split_const_subgraphs (#67784)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67784

The FX model generates quant/dequant layers for INT8 explicit-mode support. However, if the inputs of a quant/dequant layer are constant, the layer will be put into the constant subgraph and optimized out, and TensorRT will then fail to parse the leftover graph. It is better to set up an optional function (`skip_folding_node_fn`) that tells `split_const_subgraphs` which nodes to skip when folding.
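
A usage sketch with a toy module, skipping all `operator.add` nodes in place of real quant/dequant layers:

```python
import operator
import torch
from torch.fx.experimental import const_fold

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x * (self.w + 1.0)   # (self.w + 1.0) is fully constant

# Without the callback, (w + 1.0) would be folded into a precomputed attr;
# the callback keeps every node it matches out of the constant subgraph.
folded = const_fold.split_const_subgraphs(
    M(), skip_folding_node_fn=lambda node: node.target is operator.add
)
print(any(n.target is operator.add for n in folded.graph.nodes))   # True: add survived
```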

Reviewed By: jfix71

Differential Revision: D32076970

fbshipit-source-id: 7dcbb4f02386f8c831d09a2f0e40bcdba904471c
2021-11-15 06:51:19 -08:00
e5f6f356da [hpc infer] fix bench perf number
Reviewed By: yinghai, jianyuh

Differential Revision: D31505288

fbshipit-source-id: e4951a7c5813e0ee38903dec4cef61531f1b4059
2021-10-08 19:11:04 -07:00
592481a5cc [fx][const_fold] Refactor to use base split module to simplify, and correctly handle non-single-Tensor outputs (#65933)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65933

We use `split_module` to split the input model that we want to const fold into const and non-const subgraphs. Previously we were taking the non-const graph and trying to hack it back into the same signature as the input model. However, this was complex and buggy.

Instead, refactor to just keep using the base split module that contains both const and non-const graphs. This means we:
- Inline the non-const graph into the split module
- Remove the const graph from the module and replace it with a getattr that will be run to insert that attr when we `run_folding`

Test Plan: Added test coverage to cover newly supported folding, and updated other tests for new strategy.

Reviewed By: yinghai

Differential Revision: D31293307

fbshipit-source-id: 6e283a8c7222cf07b14e30e74dffc8ae5ee8b55f
2021-10-01 10:26:29 -07:00
31f61122da Creating a helper function to generate a unique name for an attr in a module (#64970)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64970

Add a helper function to create a unique name for an attr.
This can be used when we want to add a weight to a module.
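
A minimal sketch of the idea (hypothetical helper name, and it ignores identifier sanitization, which a real utility would also need):

```python
import torch

def get_unique_attr_name(module: torch.nn.Module, base: str) -> str:
    # Probe candidate names until one is not already taken on the module.
    name, i = base, 0
    while hasattr(module, name):
        i += 1
        name = f"{base}_{i}"
    return name

lin = torch.nn.Linear(2, 2)
name = get_unique_attr_name(lin, "weight")      # "weight" is already taken
setattr(lin, name, torch.nn.Parameter(torch.zeros(2, 2)))
print(name)   # weight_1
```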

Test Plan: run CI.

Reviewed By: jfix71

Differential Revision: D30921497

fbshipit-source-id: 598569d107df8b516ff12920a4bef3a42577e987
2021-09-20 14:35:56 -07:00
19471c54a6 [fx const fold] fix a case when some inputs are unused (#65223)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65223

If there are unused inputs, they won't appear in `submod_1`. We need to add back all the unused inputs so that the model after const folding has the same inputs as the original model.

Reviewed By: jfix71

Differential Revision: D31021217

fbshipit-source-id: b7452c90d133b747e0699936a81d3fee14af9cc9
2021-09-17 17:29:55 -07:00
f9c0a39ad9 add a test case for const fold (#65224)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65224

Add a test case for the fix D30996277 (8c38d141df).

Test Plan: buck test mode/opt -c python.package_style=inplace -c fbcode.nvcc_arch=v100,a100 -c fbcode.enable_gpu_sections=true -j 40 caffe2/test:fx_const_fold -- test_const_fold_module_attr

Reviewed By: jfix71

Differential Revision: D31000386

fbshipit-source-id: f444361839decc583bf93ac946cfe2049376719e
2021-09-17 10:32:07 -07:00
8c38d141df Add back the owning_module fix (#65159)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65159

This was a legit fix originally introduced in D30905949 (446d95a7f6). But we hesitated and removed it for some reason. Putting it back.

Reviewed By: 842974287

Differential Revision: D30996277

fbshipit-source-id: 3f5eede11dba2072e7cd5ae6ca7ac81d55fb75fa
2021-09-16 19:29:56 -07:00
446d95a7f6 [fx const fold] fix some cases with deep model hierarchy (#64945)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64945

In the const folding pass, we try to create `get_attr` nodes in submod_1 for `get_attr` nodes that are in the main graph, but the real attributes don't exist on submod_1. To fix this, we assign the main module as the owning module of submod_1's graph.

The fix above would cause a problem for `call_module` nodes in submod_1, because during the split, modules get inlined into submod_1 (the target changes from "mod.a.b" to "mod_a_b"). Changing the owning module would make those `call_module` nodes unable to find the module they refer to. To fix this, we set their targets to point at the main module.

Reviewed By: jfix71

Differential Revision: D30905949

fbshipit-source-id: cd67bc8fe4b8ad4344ae97b8e36753fdce3ece6d
2021-09-14 09:45:44 -07:00
be091950d0 [const_fold] Keep around node.meta for replaced folded ops (#64782)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64782

Previously, `get_attr` nodes that were added to the graph did not retain `node.meta` after folding. Add such support, and improve coverage in general here.

Test Plan: Added test coverage.

Reviewed By: protonu

Differential Revision: D30852704

fbshipit-source-id: ece87a61c69b2e68982964c6adc4dde14dae12c7
2021-09-09 23:52:39 -07:00
b12150608e [fx] make const fold code more pythonic (#64451)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64451

No functional change.

Test Plan:
```
buck test caffe2/test:fx_const_fold
```

Reviewed By: jfix71, RoshanPAN, houseroad

Differential Revision: D30718255

fbshipit-source-id: 95f98561c7f33fcc6c839db68683c85eb152c949
2021-09-08 13:55:10 -07:00
2341ec9ef1 [fx_const_fold] Fix constant folding for attrs in submodule hierarchies (#64342)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64342

Previously we weren't handling the case where an attribute was in a module that wasn't the root.

Test Plan: Added unit test coverage.

Reviewed By: yinghai

Differential Revision: D30691730

fbshipit-source-id: b39b5cf748c4c882f315a4f32b51ad88cc7a43ed
2021-09-07 22:44:39 -07:00
a1c5eba4bd [FX] Move some heavily used passes out of experimental (#51392)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51392

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D26161172

Pulled By: jamesr66a

fbshipit-source-id: 04bfe606555bdf1988f527231d4de2e0196e6b37
2021-02-01 19:02:26 -08:00
38ed398580 [fx] Add constant folding pass (#48443)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48443

Add a constant folding pass in FX (a usage sketch follows the list below):
- Iterate over an input graph and tag which nodes are fully constant, i.e. either `get_attr` nodes, or nodes all of whose inputs are either `get_attr` or constant
- Use `model_transform.split_by_tags()` to split the graph into two
- Look for the `output` node in the constant graph to get the names of attrs that will be folded
- Iterate over the non-constant graph and replace placeholders that use the same name as the attrs with a `get_attr`, plus a dummy attr on the module
- Return these two graphs in a new `FoldedGraphModule`, which is a normal GraphModule but also stores the constant graph on the side along with a `run_folding()` method that will run const folding and update the dummy parameters with the actual folded parameters
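
A usage sketch of the resulting API on a toy module (`device_for_folded_attrs`, a later addition visible earlier in this log, is pinned to CPU here so the sketch runs without a GPU):

```python
import torch
from torch.fx.experimental import const_fold

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(3))

    def forward(self, x):
        return x + (self.w * 2.0)            # (w * 2.0) depends only on constants

folded = const_fold.split_const_subgraphs(M(), device_for_folded_attrs="cpu")
folded.run_folding()                         # compute (w * 2.0) once, store it as an attr
print(folded(torch.randn(3)).shape)          # the non-const graph reuses the folded attr
```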

Test Plan: Added a couple tests

Reviewed By: 842974287

Differential Revision: D25033996

fbshipit-source-id: 589c036751ea91bb8155d9be98af7dbc0552ea19
2020-12-13 18:06:07 -08:00