This is follow-up of #165037. It generally recommended to use `is/is not` to compare types. Therefore this series of changes apply this suggestion in the code base, and it aims to finally enabling related linter checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
Its unclear why we had disable in the first place. With
install_free_tensors, we are tracing into this hook. A better way would
be to place the tracer without any hook. For now, disable the checking
while dynamo is tracing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164084
Approved by: https://github.com/tugsbayasgalan
Investigated together with @pyemma and @taotaohuang001
## Problem
when calling exported module with dict nested in the args tuple, it will make following complaits
```
Traceback (most recent call last):
File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
raise e
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
raise ValueError( # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
*])]),
TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
*])]),
TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
```
## How to reproduce the issue
```python
import torch
# create a nn.Module with data_batch as input and output as output
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = torch.nn.Linear(10, 1)
def forward(self, data_batch):
h1 = self.linear(data_batch["a1"])
h2 = self.linear(data_batch["a2"])
return h1 + h2
# torch export this module
model = MyModel()
example_args_forward = (
{
"a1": torch.randn(10),
"a2": torch.randn(10),
},
)
exported_model = torch.export.export(model, example_args_forward, strict=True)
# save the exported model
torch.export.save(exported_model, "exported_model.pt2")
# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()
# run the exported model
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
```
## Root Cause
Input spec is encoded as [TreeSpec](582d278983/torch/utils/_pytree.py (L1059)) in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution [hook](582d278983/torch/export/_unlift.py (L66)) to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like
TreeSpec(dict, ['a2', 'a1'], [*,*])
To workaround this, the input check reorders [kwargs](582d278983/torch/export/_unlift.py (L67)), that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors.
## Solution
Update eq_spec to handle the dict case, where we only guarantee that key set is the same without ordering constraints.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162618
Approved by: https://github.com/angelayi
This PR is quite large in that it covers most of rough edges in the new strict export flow:
1. Handle nn_module_stack correctly now that we are tracing wrapper module
2. module_call_spec needs to get queried from source directly because we are not running the bytecode anymore.
3. Correct input and output handling.
@diff-train-skip-merge
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
Summary: This PR introduces shape guards to export. Previously only value ranges, equalities, and specializations would be tracked for symbolic expressions, and we had a forward hook to check them. Instead now we create a function to check shape guards and call it in the exported program.
Test Plan:
updated several tests
Rollback Plan:
Differential Revision: D80713603
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161178
Approved by: https://github.com/tugsbayasgalan
Summary:
In HF model rwkv, we have parameter mutation under inference mode which should be safe. This PR does multiple things to make sure it works:
1. We execute global autograd mutation while tracing so that we can actually trace through parameter inplace mutation
2. Add support for parameter mutation under inference mode in AOTAutograd
3. Add support for parameter mutation under inference mode in export.
Test Plan:
test
Rollback Plan:
Differential Revision: D79460136
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159661
Approved by: https://github.com/ydwu4
Collects some scattershot improvements made while attempting to enable training for AOTInductor. Non-typing changes are:
1. Swapping a few custom searches for the output node in an FX graph for calling `graph.output_node()`.
2. Removing two unused parameters from `torch.export._unlift._unlift`.
3. Switching handles to constants in `cpp_wrapper_cpu` to use C++ references for memory efficiency.
4. Cleaning out unused, unexported imports from `torch/export/__init__.py`, and adding one missing export to `__all__`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158075
Approved by: https://github.com/Skylion007
Summary:
Adding "from_node" information that indicates which nodes are unlifted in `.module()` call.
The lifted nodes will have "ExportedProgram.module().unlift()" passname in the last entry of from_node.
Test Plan:
```
buck run fbcode//caffe2/test:test_export -- -r test_from_node_metadata_export
```
Rollback Plan:
Reviewed By: angelayi
Differential Revision: D75837494
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155053
Approved by: https://github.com/angelayi
# Fix typo errors across PyTorch codebase
This PR fixes various spelling errors throughout the PyTorch codebase to improve documentation quality and code readability.
## Changes Made
### Documentation Fixes
- Changed "seperate" to "separate" in multiple files:
- `setup.py`: Build system documentation
- `torch/_library/triton.py`: AOT compilation comments
- `torch/csrc/dynamo/compiled_autograd.h`: Node compilation documentation
- `torch/export/_unlift.py`: Pass population comments
- `torch/export/exported_program.py`: Decomposition table notes
### Code Comments and Error Messages
- Changed "occured" to "occurred" in:
- `test/mobile/test_lite_script_module.py`: Exception handling comments
- `torch/export/_draft_export.py`: Error message text
- `aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp`: MAGMA bug comment
- `torch/csrc/utils/python_numbers.h`: Overflow handling comment
- `torch/csrc/jit/OVERVIEW.md`: Graph compilation documentation
- `torch/_dynamo/symbolic_convert.py`: Error explanation
### API Documentation
- Changed "fullfill" to "fulfill" in `torch/distributed/checkpoint/state_dict_loader.py`
- Changed "accross" to "across" in:
- `torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp`
- `torch/distributed/distributed_c10d.py`
## Motivation
These changes improve code readability and maintain consistent spelling throughout the codebase. No functional changes were made; this is purely a documentation and comment improvement PR.
## Test Plan
No testing required as these changes only affect comments and documentation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148262
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Summary:
Add experimental support for torch.nn.Module as input types.
Before this change, we don't support module inputs but recently we saw some interesting use cases like gpt-fast https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L68 where we directly pass in a module input for different variants of the same models.
Since we don't really care about non-param or non-buffer states in non strict mode, we don't care about those either and pretend they are like plain constants during tracing. We treat any module input like a nested container of tensor, and each time we will automatically register a pytree handler for these module types to flatten its state dict into a group of tensors. We will just inline any module method call during tracing like we did for `self` module in export_for_training. This will make input modules' behavior very similar to the training module in typical case, except that we don't record the inputs as parameter or buffers but rather just plain user inputs.
Test Plan: buck run mode/opt caffe2/test:test_export -- -r test_module_input
Differential Revision: D67680827
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143925
Approved by: https://github.com/tugsbayasgalan
In this diff, i make test_torchbind.py tests to handle training IR. Today in the training IR, we don't see the effect token and HOP because this happens at the FunctionalTensorMode. Maybe in the future, we should move this logic up to the training IR so that writing passes etc on training Ir is safer. But for the migration purposes, i think it is ok for now. I also fixed two bugs:
1. ep.module() doesn't register all aliased constants in the module.
2. When we retrace, we need to fakify the original Torchbind object.
3. We don't run any DCE on training IR so we need to add some more torch ops to verifier.
Differential Revision: [D64853530](https://our.internmc.facebook.com/intern/diff/D64853530)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138658
Approved by: https://github.com/ydwu4, https://github.com/zhxchen17
Type annotations for compile_fx.
- Some of the stuff here is pretty complicated (functions which return functions that take functions) so I bailed on those and used `Any` just to get the rest landed.
- There are also changes to type signatures in other files which I did just to let mypy know more about the types in compile_fx.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138033
Approved by: https://github.com/Skylion007
When we populate unlifted graph module, we actually only "unlift" constant tensor inputs which is problematic because export de-duplicates aliasing constants. As a result, we only register one constant instead of two constants. This PR fixes that by querying ep.constants table instead of ep.graph_signature.lifted_tensor_constants.
Differential Revision: [D63743111](https://our.internmc.facebook.com/intern/diff/D63743111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137162
Approved by: https://github.com/pianpwk
Summary:
Fixes T198245910.
In previous diff D60532628 that causes the test failure, we fix the in-consistency caused by constant tensors is accidentally reigistered as buffer by deleting the buffer and re assign them as constant.
However, this broke several existing tests in pyspeech when the exported program is re-traced with torch.jit.trace (which is an anti-pattern we probably should have some alignment), the jit tracer finds this constant tensor requiring grad and errors out.
This PR force constant attr not requiring grad, which is the correct behavior. A better fix is finding out where the constants are created in user code and why it requires grad. But this has low roi so we warn user about it.
Test Plan: See failures in T198245910.
Differential Revision: D60974869
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133031
Approved by: https://github.com/angelayi
Summary:
A re-land of D60006710.
Fixed TrainingIRToRunDecomp failures for test_tensor_attribute_zero_args and also a few re-tracability failures because run_decomposition does a retracing.
edit: also remove the eliminate_dead_code() in _unlift because of one onnx test failure:
a constant tensor attr was lifted as constant_tensor input but it's not used in the graph after aot_autograd due to a short cut in its decomposition. This causes the setattr to be removed by eliminate_dead_code but the graph signature still contains the name of that buffer, which causes an inconsitency between the transformed graph and ep's original signature after _unlift. And it seems that this has happened a few times where some nodes are accidentally removed and we're in an inconsistent state.
The alternative of removing it would be: every time we call elimiate_dead_code, we verify the consistency of the graph with 1. the graph before transformation and 2. all the meta datas but i think this deserves a complete design
edit 2: Also fix the inconsistency of graph signatures when param_constant is marked as lifted_tensor_constants but it's registered as parameters in the output of ep.module().
Differential Revision: D60532628
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132307
Approved by: https://github.com/zhxchen17
Summary:
- make default DCE pass check schema,
- need to rebase onto https://github.com/pytorch/pytorch/pull/131651 after it's in phabricator (for now the change is manually added).
- mark Proxy dump as NotImplemented for better error msg
- Remove Proxy from tensors when dumping models, as Proxy cannot be dumped.
More details in https://docs.google.com/document/d/1G5vmTXjzxoyVGRI2kpA1gQukK_Glyg2NrE0Oh6Nlg9A/edit?usp=sharing.
Test Plan:
CI
```
- buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r qat_conv2d
- test_export.py
- buck2 run 'fbcode//mode/dev-nosan' fbcode//modai/test:test_modai -- -r test_qat_stinson_htp_export
- buck2 run 'fbcode//mode/dev-nosan' fbcode//vizard_projects/ml_depth/tests:test_model -- -r test_qat_model_et
- buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:fx -- -r dce
- buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=False,use_3d_input=False
- buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=True,use_3d_input=False
- buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_fold_bn_erases_bn_node
```
Reviewed By: angelayi
Differential Revision: D60319175
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132764
Approved by: https://github.com/angelayi
Summary:
This pr fixes all the places in strict export stack where the output node's meta is not preserved correctly. However, we're getting a new error for the test we intend to fix: `buck2 run caffe2/test/quantization:test_quantization -- -r "test_re_export_preserve_handle"`:
The `get_attr` nodes has wrong metadata. I guess there are more things need to be fixed to get it working but it's beyond the scope of this PR.
Test Plan: buck2 run caffe2/test/quantization:test_quantization -- -r "test_re_export_preserve_handle"
Differential Revision: D60198221
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131706
Approved by: https://github.com/yushangdi
Summary:
This diff reverts D59561509
D59561509: [FX][export] DCE pass, check schema for node impurity (#130395) by yushangdi causes the following test failure:
Tests affected:
- [cogwheel:cogwheel_mtia_cmf_m5_shrunk_test#test_flow_with_verification](https://www.internalfb.com/intern/test/844425041436985/)
Here's the Multisect link:
https://www.internalfb.com/multisect/6533402
Here are the tasks that are relevant to this breakage:
T191383430: 10+ tests unhealthy for ads_mtia_inference
The backout may land if someone accepts it.
If this diff has been generated in error, you can Commandeer and Abandon it.
Test Plan: NA
Differential Revision: D60029318
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131341
Approved by: https://github.com/angelayi
Fixed TrainingIRToRunDecomp failures for test_tensor_attribute_zero_args and also a few re-tracability failures because run_decomposition does a retracing.
**edit:** also remove the eliminate_dead_code() in _unlift because of one onnx test failure:
a constant tensor attr was lifted as constant_tensor input but it's not used in the graph after aot_autograd due to a short cut in its decomposition. This causes the setattr to be removed by eliminate_dead_code but the graph signature still contains the name of that buffer, which causes an inconsitency between the transformed graph and ep's original signature after _unlift. And it seems that this has happened a few times where some nodes are accidentally removed and we're in an inconsistent state.
The alternative of removing it would be: every time we call elimiate_dead_code, we verify the consistency of the graph with 1. the graph before transformation and 2. all the meta datas but i think this deserves a complete design.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130990
Approved by: https://github.com/pianpwk
Fixes the failure in `test/export/test_export_training_ir_to_run_decomp.py ` caused by dead code elimination removing node with side effects.
For background, in export, we may want to export higher-level IRs that are not functional, so we need to check for side effects more carefully.
A call_function node is impure if it has at least one mutable argument.
Fixed the tests below:
test_to_module_with_mutated_buffer_multiple_update_sub_later
test_export_input_mutation_static_shape
test_buffer_util
Another attempt modifying the original DCE pass is made in PR #130395, but it breaks some other tests, so here we add a flag and use it for export only.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130552
Approved by: https://github.com/pianpwk
This PR intends to support the aten operations with the `out` tensor.
Currently, the AOT compile always does **NOT** keep input tensor mutations. According to the comments, this is because it has not encountered such a use case.
> For now there's no use case involving keeping input mutations in the graph (which we can only do in the inference case anyway). We can add this later if we need to.
However, for aten operations, it is popular that the `out` tensor is an input parameter and needs to be mutated. This PR intends to support it by adding a `keep_inference_input_mutations` flag to `aot_inductor.keep_inference_input_mutations`. This flag can provide flexibility to the callee in deciding whether the AOT compile needs to keep input tensor mutations in the graph.
Take `clamp` as an example as follows.
```python
out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(-2.0)
inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0)
min_tensor = inp_tensor - 0.05
max_tensor = inp_tensor + 0.05
torch.clamp(input=inp_tensor, min=min_tensor, max=max_tensor, out=out_tensor)
```
W/O this PR
```python
def forward(self):
arg0_1: "f32[128]"; arg1_1: "f32[128]"; arg2_1: "f32[128]"; arg3_1: "f32[128]";
arg0_1, arg1_1, arg2_1, arg3_1, = fx_pytree.tree_flatten_spec([], self._in_spec)
clamp_min: "f32[128]" = torch.ops.aten.clamp_min.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
clamp_max: "f32[128]" = torch.ops.aten.clamp_max.Tensor(clamp_min, arg2_1); clamp_min = arg2_1 = None
return (clamp_max, clamp_max)
```
W/ this PR
```python
def forward(self):
arg0_1: "f32[128]"; arg1_1: "f32[128]"; arg2_1: "f32[128]"; arg3_1: "f32[128]";
arg0_1, arg1_1, arg2_1, arg3_1, = fx_pytree.tree_flatten_spec([], self._in_spec)
clamp_min: "f32[128]" = torch.ops.aten.clamp_min.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
clamp_max: "f32[128]" = torch.ops.aten.clamp_max.Tensor(clamp_min, arg2_1); clamp_min = arg2_1 = None
copy_: "f32[128]" = torch.ops.aten.copy_.default(arg3_1, clamp_max); arg3_1 = clamp_max = None
return (copy_,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124926
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/angelayi