Summary:
We split the refactoring in two parts for forward compatibility concerns
First, we land the deserialization (loading part)
Then, we land the serialization (saving part)
Save weights and constants as individual files in PT2 archive. Each weight/constant will be saved as raw bytes, unless it is a custom object (TorchBind object) or a non-fake tensor subclass, for these two special cases we still save them using pickle.
The metadata of saved tensors along with the file name will be saved as `PayloadMeta`.
The mapping from FQN to `PayloadMeta` will be saved as `PayloadConfig` under `WEIGHTS_CONFIG_FORMAT` and `CONTANTS_CONFIG_FORMAT`
This changes the serialization in python side when calling `torch.export.save()`.
For deserialization in python `torch.export.load()`, we make it BC-safe by allowing loading legacy format weights/constants.
For deserialization in C++ `torch/nativert/ModelRunner.cpp`, we make this a BC breaking change as currently the OSS ModelRunner API is not being used.
The file structure
```
├── archive_format
├── archive_version
├── byteorder
├── .data
│ ├── serialization_id
│ └── version
├── data
│ ├── sample_inputs
│ │ └── model.pt
│ ├── constants
│ │ ├── tensor_0
│ │ ├── tensor_1
│ │ └── model_constants_config.json
│ └── weights
│ ├── weight_0
│ ├── weight_1
│ ├── weight_2
│ ├── weight_3
│ └── model_weights_config.json
└── models
└── model.json
```
Test Plan:
CI
Rollback Plan:
Differential Revision: D80035490
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160394
Approved by: https://github.com/SherlockNoMad
Summary:
[The diff was reverted due to CLA error, in the process of retrieving account]
Previous error message
```
RuntimeError: Expected input at *args.<unknown location>.shape[0] to be equal to 4096, but got 7680. If you meant for this dimension to be dynamic, please re-export and specify dynamic_shapes (e.g. with Dim.DYNAMIC)
```
New error message
```
RuntimeError: Expected input at *args.[0].supervision_input.weight.shape[0] to be equal to 4096, but got 7680. If you meant for this dimension to be dynamic, please re-export and specify dynamic_shapes (e.g. with Dim.DYNAMIC)
```
Test Plan:
```
buck test mode/opt apf/rec/ir/tests:ir_export_deserialize_test
```
https://www.internalfb.com/intern/testinfra/testrun/4785074906254375
```
buck run mode/opt caffe2/test:test_export -- -r unflatten
```
```
Ran 413 tests in 208.414s
OK (skipped=1, expected failures=13)
```
Rollback Plan:
Differential Revision: D80487367
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160919
Approved by: https://github.com/angelayi
Summary: Sometimes the call history recorded in a `nn_module_stack` does not have the stack property, where each FQN is a prefix of the next FQN. This can cause errors during `unflatten`. Instead of erroring we now drop entries from such a `nn_module_stack` to restore the stack property. This effectively leads to less unflattening: the last FQN in the call history before the stack property was broken keeps the entire flat subgraph of its call.
Test Plan:
added test, updated another
Rollback Plan:
Differential Revision: D79204669
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159418
Approved by: https://github.com/angelayi
Preview: https://docs-preview.pytorch.org/pytorch/pytorch/157750/export.html
Changes:
* Rename draft_export.md -> export.draft_export.md for consistency.
* Removed non-strict section in export, instead pointed to programming model doc.
* Extended "Expressing Dynamism" section to include Dim hints, ShapeCollection, and AdditionalInputs.
* Removed Specialization section in favor of programming model doc
* Added pt2 archive doc
* Cleaned up sidebar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157750
Approved by: https://github.com/pianpwk
Summary:
1. Adding `input` field to `_adapt_flat_args` function
2. In `process_forward_inputs`, `reorder_kwargs` will now do nothing if no kwargs are provided (previously would error)
3. Pass `args` as input to `_adapt_flat_args`
These changes are made to update the InputAdapter
see more context in D73811508
Test Plan: see D73811508
Differential Revision: D73945419
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152575
Approved by: https://github.com/angelayi
Summary:
Recall that we use "ivals" to track intermediate values of mutations during unflattening. Previously, for each such intermediate value, we would create a hidden shared attribute that would be updated / read by respective submodules.
Unfortunately this scheme doesn't work when some but not all of those submodules are swapped out. This is because the swapped in submodules have no knowledge of these hidden attributes. Thus the submodules that are not swapped out end up reading / updating dangling state.
This PR does away with these hidden attributes. Instead, we directly read the underlying buffer or placeholder that was updated, and update those underlying buffers and placeholders in place. This makes the graphs look much closer to their eager origins.
Test Plan: added some tests, ensured existing tests pass
Differential Revision: D71203469
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149206
Approved by: https://github.com/tugsbayasgalan
Summary: Add support for inputs that no longer exist in `input_fields`, but is not actually used by the original program. In this case, we just give it a dummy input based on the node's metadata.
Test Plan: Verified for S488841
Differential Revision: D69328093
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147238
Approved by: https://github.com/pianpwk
If a user passes in a namedtuple as an input, currently the input TreeSpec looks like: `TreeSpec(type=namedtuple, context=”class_fqn”, children_spec=[*, *])`
The user then saves the program containing this input TreeSpec. But what happens if they load it in a new environment where `class_fqn` now contains an additional field?
This means that the exported program is now expected to take in another input. But since those fields were not used in the original program, users should be able just drop those additional fields and the program will run successfully. This is needed/used in APS where they use unflattener's adapter to adapt the inputs based on the previously saved treespecs.
There are a couple of [solutions](https://docs.google.com/document/d/1V4ZSdy-8PUISWc8RqvGu3DU01BVegJhHHPWqa1Io7Eg/edit?tab=t.0) for how we can address this, but eventually we settled on saving a side table mapping namedtuple types to their list of field names, which can then be accessed by the adapter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145956
Approved by: https://github.com/zhxchen17
When we unflatten, the submodules we generate (`InterpreterModule` or `InterpreterModuleDispatcher`) are not related by type to the original submodules `N`. This makes `isinstance(mod, N)` checks fail. Since we do not have the original types after export, the best we can do is expose a `type_name()` method that carries the original type name, which we do carry in `nn_module_stack` entries.
Differential Revision: [D67526542](https://our.internmc.facebook.com/intern/diff/D67526542/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143664
Approved by: https://github.com/tugsbayasgalan
Combining several fixes to unflatten for bugs revealed by random graph testing.
The fixes target two categories of bugs:
1. Some bugs show up as exponential blowups for largish system of nn modules. These are fixes by converting lists to sets, using caching, or otherwise rewriting to reuse computation more effiicently.
2. Other bugs were due to missing intermediate modules created when attributes such as submodules and buffers are accessed through longish paths before calling the corresponding intermediate modules, or missing attributes such as buffers and constants in submodules corresponding to multiple calls.
Differential Revision: [D66659795](https://our.internmc.facebook.com/intern/diff/D66659795/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142141
Approved by: https://github.com/ydwu4
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.
Having these `# type: ignore` linger around is not ideal for two reasons:
- They are verbose/ugly syntatically.
- They could hide genuine bugs in the future, if a refactoring would actually introduce a bug but it gets hidden by the ignore.
I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional" like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address these at a later point, and eventually would enable `warn_unused_ignores = True` in the mypy configuration as discussed in that comment to prevent accumulating more dead ignores going forward.
This PR should have no effect on runtime at all.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
With largish systems of nn modules with buffers, sinking params suffered from some kind of exponential blowup that is easily fixed by using a set instead of a list to keep track of unlifted buffer placeholders.
Test Plan: added random dag test that failed previously
Differential Revision: D66457661
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141494
Approved by: https://github.com/angelayi
Handling of nested modules in unflatten had several bugs, which were caught by trying to preserve module call signatures for nested modules.
* A module `k` encountered when calling `k.n()` before `k()` used to become an empty nn module. This caused some information to be dropped when `k()` was eventually called. Relatedly, we would also lose call counts for `k.n()` through different paths (say, when `k()` calls `n()`).
* Deleting call-indexed modules and patching up their call sites was broken for nested modules when creating dispatcher modules, because of silliness when handling their fqns.
An interesting aside is that we used random graph generation for testing some of these changes. A future PR will add the infra to create tests using these random graphs.
Differential Revision: D66192799
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141066
Approved by: https://github.com/angelayi
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Differential Revision: [D65307961](https://our.internmc.facebook.com/intern/diff/D65307961/)
This PR introduces the concept of a "dispatcher" module `n` that carries multiple interpreter modules `n`, `n@1`, `n@2`, etc., each corresponding to a particular call of `n` and thus might carry a different specialized graph. We only do this when we're preserving module call signatures for `n`. The carried modules have the same number and order of calls to `n` appearing in the original module / exported program. In the unflattened module, all those calls go to the "dispatcher" module which internally tracks how many calls have been made so far and invokes the corresponding interpreter module. We reset this tracking after a successful or unsuccessful run of the unflattened module.
Overall this makes swapping easier when module call signatures are preserved.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139439
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #139438
# Why?
I want the following code to work.
minimal repro:
```
class M(torch.nn.Module):
def forward(self, dilate_flag):
return dilate_flag.item()
input1 = (torch.tensor([1], dtype=torch.bool, device="cuda"),)
model = M().cuda()
ep = torch.export.export(model, input1, strict=True)
path = torch._inductor.aot_compile(ep.module(), input1)
aot_model = torch._export.aot_load(path, device="cuda")
actual_output = aot_model(*input1)
```
error: AssertionError: Encountered an unsupported object of type <class 'torch.SymBool'> while writing the metadata for exported program
second error will be handled by https://github.com/pytorch/pytorch/pull/138760
# Motivation
I could technically bypass it with a torch.int tensor. However, it doesn't work with torch.cond. I want the following to work. It would also require https://github.com/pytorch/pytorch/pull/138760 for aot compile to work.
```
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.dilate_flag = 0
def forward(self, dilate_flag):
self.dilate_flag = dilate_flag.item()
def true_fn(dilate_flag):
return dilate_flag.clone()
def false_fn(dilate_flag):
return dilate_flag.clone()
torch.cond(
self.dilate_flag,
true_fn,
false_fn,
(dilate_flag,),
)
return self.dilate_flag
input1 = (torch.tensor([1], dtype=torch.bool, device="cuda"),)
input2 = (torch.tensor([0], dtype=torch.bool, device="cuda"),)
inputs = (input1, input2)
model = M().cuda()
for input in inputs:
expected_output = model(*input)
ep = torch.export.export(model, input, strict=False)
path = torch._inductor.aot_compile(ep.module(), input)
aot_model = torch._export.aot_load(path, device="cuda")
actual_output = aot_model(*input)
assert (
expected_output == actual_output
), f"henry they are not equal {expected_output} != {actual_output}"
```
Differential Revision: D64867504
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138765
Approved by: https://github.com/ydwu4
Summary:
Unflatten was broken for HOPs for a couple of reasons:
(1) we didn't expect `get_attr` nodes in the exported program, but they can occur to hold graph arguments to HOPs; such attributes must be moved from the exported program to the corresponding unflattened submodule containing the HOP call.
(2) we don't record metadata for graph arguments on serialization (there's nothing to hold it in our schema), and accordingly the `get_attr` nodes we create on deserialization don't have `nn_module_stack` metadata, which obviously wrecks unflatten.
Test Plan: added a couple of tests
Differential Revision: D65013647
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138978
Approved by: https://github.com/zhxchen17