Summary: This PR introduces shape guards to export. Previously only value ranges, equalities, and specializations would be tracked for symbolic expressions, and we had a forward hook to check them. Instead now we create a function to check shape guards and call it in the exported program.
Test Plan:
updated several tests
Rollback Plan:
Differential Revision: D80713603
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161178
Approved by: https://github.com/tugsbayasgalan
Summary: ONNX team and recent transformer upgrade ran into this error and we also ran into during our export benchmarking. This diff makes it possible to trace through vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR and in the future, we should desugar them to post-grad ops.
The implementation strategy is:
1. We add python wrappers around vmap APIs so that we attach custom torch function handler that is only on during non-strict export. The reason is we don't want to add this to default torch_function handler because it will break BC.
2. Some dynamo changes to make sure it picks up new python wrapper APIs. The reason is when we do strict export, we need to re-materialize these APIs in pre-dispatch IR from torch IR. We can avoid this by special casing in dynamo for export to proxy different API calls but i feel that is too much chaos because you need to be able to proxy 2 different variants of same vmap API.
Test Plan: CI
Differential Revision: D75623875
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154650
Approved by: https://github.com/ezyang, https://github.com/zou3519
As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo
This PR adds strict typing support to a critical set of files for dynamo, `source.py` and the base `_guards.py`
Running
```
mypy torch/_dynamo/source.py torch/_guards.py --linecount-report /tmp/coverage_log
```
| -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main | 1227 | 2208 | 55.57% | 207 | 362 | 57.18% |
| This PR | 2217 | 2217 | 100.00% | 362 | 362 | 100.00% |
| Delta | +990 | +9 | +44.43% | +155 | 0 | +42.82% |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158397
Approved by: https://github.com/anijain2305
Basically adds native _IntWrapper support to dynamo. Here's my process of trying to make symint input support work on dynamo, and how I ended up with this approach [(doc)](https://docs.google.com/document/d/1GvNRQd8BnxlMay_hrEVgEta6VUeUW_hcFeRuB7q1nDY/edit?tab=t.0).
What I did was, before passing inputs to dynamo.export, I first wrap them with a class, `_IntWrapper`. When processing dynamic shapes, I will then add the corresponding dynamic shape specification to the `dynamism` field stored on the `_IntWrapper`. If there is no dynamism specified, then this will get unwrapped back to an integer. When dynamo tracing, when we encounter an `_IntWrapper`, we will convert this to a symint if the dynamism was specified as `Dim.DYNAMIC/AUTO`. Dynamo will then trace a graph that contains symint inputs, which will get passed to AOTAutograd and so on.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152677
Approved by: https://github.com/pianpwk
Summary:
Fixes https://github.com/pytorch/pytorch/issues/151476
The `custom_meta` collected from `mod` has keys that follow name of nodes in `mod`, which are inconsistent with the node names after the naming pass. For example a constant `b` will become `c_b`.
Test Plan: buck2 run caffe2/test:test_export -- -r test_run_decompositions_keep_tensor_constant_metadata
Differential Revision: D73703068
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152241
Approved by: https://github.com/angelayi
Changes decomposition behavior of `aten.to` to respect the aliasing/non-aliasing behavior in eager, and to specialize to the input/conversion dtype & device.
Before change: we always decompose `aten.to` into `_to_copy`, regardless of aliasing behavior. This leads us to ban mutations on the result of `_to_copy` when aliased, since we can't guarantee correct program semantics. This meant users had to explicitly call `.clone()` before mutating. In the special cases where we don’t ban mutations (e.g. dtype conversion), we add runtime assertions on the input & conversion dtype/devices in the decomposed program (see https://github.com/pytorch/pytorch/pull/142420).
After change: we decompose to the aliasing/non-aliasing behavior that matches eager, allowing mutations in all cases. We also add dtype/device assertions for all `aten.to` ops, starting in the pre-dispatch graph, basically specializing the program to the dtype/devices.
Differential Revision: D71229547
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149235
Approved by: https://github.com/tugsbayasgalan
The `export` API takes a `nn.Module` and traces its `forward` method. However sometimes it is useful to export different methods of a `nn.Module`, either as a one-off for debugging or as a set of methods that are called in some sequence outside `export` (e.g., `encode` / `decode`). When multiple methods of the same module instance are exported, they should share the same of the common module instance.
This PR adds a couple of utils in `torch._export.utils` for this workflow.
The `wrap_method` util wraps a method as a `nn.Module` that can then be exported. See included test. We recommend using the same module instance to export multiple methods on that instance, in which case they are guaranteed to share state. On serde, this state sharing is lost, so we provide another util, `sync_state`, to re-sync the state.
These utils are meant to be eventually replaced by API-level changes, but for now this can unblock users who need this workflow. In particular, in the future we can accept one or multiple method entrypoints, with their own args / kwargs / dynamic shape specifications, which can create a variant of `ExportedProgram` with multiple graphs that share state; then we can automatically ensure that the state sharing is preserved through serde.
Differential Revision: [D69960801](https://our.internmc.facebook.com/intern/diff/D69960801/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147573
Approved by: https://github.com/tugsbayasgalan
Previously, in non-strict path, we always error when trying to inplace update a constant tensor because those constant tensors are not actually wrapped by functional tensors. This is correct behaviour in torch.compile, because dynamo makes all constant tensors into buffers and AOTDispatcher just lifts them and wraps them in functional tensors. However, in non-strict, there is no such step that registers constants as buffers so AOTDispatcher panics when it sees these dangling constant tensors when functioanalizing.
Due to recent change in the IR, this is no longer an issue in non-strict path because we don't call AOTDispatcher at training IR level, but now it is a problem for both strict and non-strict when we lower to inference. (lowering to inference is very similar to non-strict tracing) As a result, we have at least one external (https://github.com/pytorch/pytorch/issues/141336) and internal issues reported due to this difference.
To fix this, there are two ways:
1. Make functionalization be aware of constant tensors and map them to functional tensors on the fly. This makes functionalization invariant uglier and could potentially open up a gate for more nasty bugs.
2. Special handle this in export. This seems more aligned with what dynamo does today so i think we should do it this way. I think the current state could benefit from more refactors to make the run_deocmpositions to be more similar to strict export (because both of them now handle this constant registerinig logic) but it is bit complicated to do it now because strict export version of this logic is also not complete because it doesn't take into account of export graph renaming pass etc). I will follow up with more refactors after this PR (T213466691) to unblock users faster.
For future reference:
Why are we not doing "turning constants into non-persistent buffers and never de-register"? The reason is because in some internal models, they rely on module.to to reliably work to move params/buffers to correct device. As a result, buffers are moved while constants are not. In composibility meeting, we agreed that export won't do device agnostic tracing going forward (it will provide a way to specify FakeTensor in CPU that can be configured to be run on GPU), so after that is done, we can always turn constants into non-persistent buffers which will simplify export's constant handling.
Differential Revision: [D68610739](https://our.internmc.facebook.com/intern/diff/D68610739)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145593
Approved by: https://github.com/avikchaudhuri
Summary:
Add experimental support for torch.nn.Module as input types.
Before this change, we don't support module inputs but recently we saw some interesting use cases like gpt-fast https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L68 where we directly pass in a module input for different variants of the same models.
Since we don't really care about non-param or non-buffer states in non strict mode, we don't care about those either and pretend they are like plain constants during tracing. We treat any module input like a nested container of tensor, and each time we will automatically register a pytree handler for these module types to flatten its state dict into a group of tensors. We will just inline any module method call during tracing like we did for `self` module in export_for_training. This will make input modules' behavior very similar to the training module in typical case, except that we don't record the inputs as parameter or buffers but rather just plain user inputs.
Test Plan: buck run mode/opt caffe2/test:test_export -- -r test_module_input
Differential Revision: D67680827
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143925
Approved by: https://github.com/tugsbayasgalan
We added an is_export flag under torch.compiler.is_exporting. This comes handy when we try to do some special logic in user-level and system-level (e.g. in upper of the stack).
In increasing-scope:
- `_is_fx_tracing` is set to True when we use under symbolic_trace or make_fx.
- `is_exporting` is set to True when we're doing strict or non-strict export, which internally has a step that calls make_fx and set _is_fx_tracing to be True.
- `is_compiling` is set to True when we're either doing strict, non-strict export or torch.compile.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142425
Approved by: https://github.com/avikchaudhuri
A bunch of auto dynamic shape tests would fail non-strict retraceability because when checking input constraints, we'd compare non-trivial expressions, which would require / affect shape env.
```
... is not tracked with proxy for <torch.fx.experimental.proxy_tensor._ModuleStackTracer object ...
```
I've also observed this bug internally.
This PR does an early check on whether args passed have concrete shapes, and only then proceeds: as before, we
1. try to unify / solve with the arg dim when the corresponding placeholder node dim is symbolic in one symbol
2. check directly if the placeholder node dim is concrete
3. otherwise defer to run time.
Differential Revision: [D67359596](https://our.internmc.facebook.com/intern/diff/D67359596/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143442
Approved by: https://github.com/tugsbayasgalan
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.
Having these `# type: ignore` linger around is not ideal for two reasons:
- They are verbose/ugly syntatically.
- They could hide genuine bugs in the future, if a refactoring would actually introduce a bug but it gets hidden by the ignore.
I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional" like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address these at a later point, and eventually would enable `warn_unused_ignores = True` in the mypy configuration as discussed in that comment to prevent accumulating more dead ignores going forward.
This PR should have no effect on runtime at all.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
For better tracking, we need to make maybe aliasing/mutating ops with proper tag. We need to special case native_batch_norm because it is not a CIA but has a wrong schema. I guess native_batch_norm will be removed at some point, so until then we just keep it around.
D60347117
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131990
Approved by: https://github.com/bdhirsh
Summary:
When we have both `set_grad` and `autocast` HOP, name collision might happen when we try to inline a node.
For exmaple, for a GraphModule like this:
```
GraphModule(
(submod_0): GraphModule(
(submod_1): GraphModule()
)
(submod_1): GraphModule()
(submod_2): GraphModule()
)
```
when we inline `submod_0`, we might accidentally overwrite `submod_1`.
In this PR, we fix this by check if the graph module already has an attribute with the same name, if so, we use the next "submod_{i}", until no name collision.
Partially fixes https://github.com/pytorch/pytorch/issues/140589.
Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_predispatch_autocast_and_set_grad
```
Differential Revision: D66200994
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141169
Approved by: https://github.com/angelayi
In this PR, we implement lazy dictionary for export decomp behaviour for following reasons:
1. Custom op loading can happen after import time, as a result, the decomp table might not be able to pick up the decomp. Therefore we try to delay materialization as late as possible.
I intentionally seperated out the core_aten_decomp to not have any custom CIA ops in this PR to mitigate the risk of getting reverted but in the future, core_aten_decomp under torch/_decomp will exist as an alias to official export table (torch.export.default_decompositions)
Differential Revision: [D64140807](https://our.internmc.facebook.com/intern/diff/D64140807)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137650
Approved by: https://github.com/justinchuby, https://github.com/bdhirsh
Resolve#137540
Summary:
We might get different state_dict and named_parameters result when the module has registered custom state_dict_hooks.
For exported_program's state_dict, we want the state_dict to reflect the actual module hierarchy at runtime, and it might be different from the model's state_dict() output if the model has state_dict hooks.
To do weight swapping, one needs to either re-export or turn-off the hooks when saving model's state_dict().
Previously, ExportedProgram uses nn.Module's state_dict() method to populate its own state_dict, but it doesn't work for some models (e.g. llama3_3_vision) because ExportedProgram's state_dict and an nn.Module's state_dict have some subtle differences semantically.
nn.Module's state_dict is about how the state should be serialized, and it reflects the structure of the original user model code. In contrast, export specializes on a “run” of a model, and its state_dict needs to reflect the runtime module hierarchy.
One example where these two are different is TorchTune's Llama3_2_vision text decoder. Here, a FusionLayer is added as a local optimization and it is not part of the "static model definition". In runtime, we have mod.layers[3].layer.sa_norm.scale.
But in nn.Module's state_dict, the authors of the model added a state_dict hook to remove the "layer" in mod.state_dict() to reflect the static model definition, so we have mod.state_dict()["layers.3.sa_norm.scale"].
In this Diff, we change ExportedProgram to populate its state_dict using named_parameters() and named_buffers() instead. So in ExportedProgram's state_dict, we have "layers.3.layer.sa_norm.scale", which reflects the runtime module hierarchy.
Now one problem this presents is weight swapping. Since ExportedProgram's state and the model's state is not the same anymore, weight swapping procedure also needs to change slightly.
In internal Ads and RecSys models deployment, weight swapping is where they have one model that is currently being being deployed and serving traffic, and they want to swap out the weights with newly trained model weights without having to redo the whole exporting/lowering process and create a new artifact. So they would move the deployed model’s pointer to the state dict over to the new state dict. Because of this, it’s previously a requirement that the FQNs are matching between the exported and the eager model’s state dict.
The new ExportedProgram's state dict still supports weight swapping, but the state_dict to be swapped needs to be obtained from torch.export.exported_program instead of model.state_dict() if the model has state_dict hooks.
The new requirement is that the FQNs are matching between the exported’s state dict and the state_dict obtained from `_disabled_load_state_dict_hooks(M)` context manager. One benefit of having this new API is that we are now in full control within export of gathering and updating the model state.
If a model doesn't have any state_dict hooks, one can still use model.state_dict() for weight swapping, so it's BC.
Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_for_training_with_state_dict_hooks
```
Differential Revision: D64080561
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137609
Approved by: https://github.com/angelayi, https://github.com/pianpwk
Summary:
Fixes https://github.com/pytorch/pytorch/issues/134778
The previous D62304294 broke some executorch tests. It has already been reverted.
In this diff, `_collect_param_buffer_metadata()` is modified in a way that when a `call_function` node is encountered and its input nodes include `get_attr`. We skip the fields that have been collected previously and only collect rest of the fields. This prevents over-writing.
Test Plan:
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//executorch/backends/xnnpack/test:test_xnnpack_ops
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_re_export_preserve_handle
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_run_decompositions_preserve_handle
```
Differential Revision: D62514208
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135720
Approved by: https://github.com/zhxchen17, https://github.com/jerryzh168