Summary:
https://docs.google.com/document/d/1QJJEGnj2nHGPODlw38BEG3KLLCOTfdOVjPrNQbz_LM8/edit#bookmark=id.lp80wfshq130
Changes:
* `torch.export` will return a functional ATen graph w/o decompositions
* `exported_program.run_decompositions(decomposition_table)` will optionally take a decomposition table, and run decompositions on the exported program, returning a new exported program. By default we will run the Core ATen decomposition table.
Calling convention for Executorch stays the same:
```
pre_autograd_graph = capture_pre_autograd_graph(f, args, ...)
aten_graph_no_decomps = torch.export.export(pre_autograd_graph, args, ...)
# Within to_edge we decompose to core aten and then convert to edge
edge_graph = exir.to_edge(aten_graph_no_decomps)
```
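For the non-ExecuTorch path, a minimal sketch of the new two-step flow (the toy module and the explicit table import are illustrative assumptions, not part of this PR):
```python
import torch
from torch._decomp import core_aten_decompositions

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

# Step 1: export produces a functional ATen graph with no decompositions applied.
ep = torch.export.export(M(), (torch.randn(4, 4),))

# Step 2: run decompositions, returning a new ExportedProgram. Calling
# run_decompositions() with no table falls back to the Core ATen table.
ep_core = ep.run_decompositions(core_aten_decompositions())
```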
Test Plan: CI
Differential Revision: D49742989
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110410
Approved by: https://github.com/ydwu4
## Context
Add decompositions for `aten.max`, `aten.min`, and `aten.var_mean`. These operators follow a pattern of returning a tuple of outputs from two component operators:
```
aten.max(x) -> return aten.amax(x), aten.argmax(x)
aten.min(x) -> return aten.amin(x), aten.argmin(x)
aten.var_mean(x) -> return aten.var(x), aten.mean(x)
```
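As a hedged sketch of this pattern (registering into a private table via `torch._decomp.register_decomposition` so it does not clash with the real registrations added by this PR; overload and argument names are assumptions):
```python
import torch
from torch._decomp import register_decomposition

aten = torch.ops.aten
my_decomp_table = {}  # private registry, just for illustration

@register_decomposition(aten.max.dim, registry=my_decomp_table)
def max_dim(self, dim, keepdim=False):
    # The tuple-returning overload decomposes into its two component reductions.
    return aten.amax(self, dim, keepdim), aten.argmax(self, dim, keepdim)
```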
For `var_mean`, the `refs` implementation was doing something similar, so I changed it to call `torch.*` ops instead, as was done previously for other `refs` implementations. cc: @peterbell10 @lezcano
Note that Inductor lowers all these directly, so they are excluded from the Inductor decomp table.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110906
Approved by: https://github.com/manuelcandales
A resubmit of https://github.com/pytorch/pytorch/pull/108447. Copy over the descriptions:
This is a follow-up to the discussion in https://github.com/pytorch/pytorch/pull/108356, where we want to replace source_fn with source_fn_stack.
Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]); l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_; l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]); l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin(); l_x_ = None
                cos = l_y_.cos(); l_y_ = None
                sub = sin - cos; sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos(); l_x_ = None
                sin = l_y_.sin(); l_y_ = None
                sub = cos - sin; cos = sin = None
                return sub
```
The source_fn for the inner cond, sin, cos, and sub nodes will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub', <built-in function sub>)
```
After this PR, the source_fn_stack will be a list of (name, target) tuples. The bottom of the stack is at the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```
Test Plan:
See the added tests in test_higher_order_ops.py and the modified existing tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108595
Approved by: https://github.com/angelayi, https://github.com/zou3519
Recently we updated the `export` API to take an experimental `dynamic_shapes` argument that was meant to subsume the existing `constraints` argument.
This PR deprecates `constraints` (with a warning on its use, but without actually removing it). Simultaneously it replaces all uses of `constraints` in docs, examples, and tests with corresponding uses of `dynamic_shapes` (preserving behavior). This exercise fortunately revealed some minor bugs in the implementation which have also been fixed in this PR.
Some uses of `constraints` still remain, e.g., when `torch._dynamo.export` is called directly. (Meta-internal uses will be updated in a separate diff.)
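As a hedged illustration of the switch (the toy module is made up; `Dim` and the dict form of `dynamic_shapes` are taken from the export API of that time):
```python
import torch
from torch.export import export, Dim

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

x = torch.randn(4, 8)
batch = Dim("batch")

# New style: `dynamic_shapes` maps each argument name to its dynamic dimensions.
ep = export(M(), (x,), dynamic_shapes={"x": {0: batch}})

# Deprecated style that now emits a warning (shown only for comparison):
# ep = export(M(), (x,), constraints=[torch.export.dynamic_dim(x, 0)])
```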
Differential Revision: D49676049
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110143
Approved by: https://github.com/tugsbayasgalan
## Context
Add existing decomps for `lift_fresh`, `split.Tensor`, and `unbind` to the core ATen decomposition table. They are not used in Inductor, since Inductor currently lowers these ops directly.
One note though is that `lift_fresh`'s decomposition has a note saying it's not correct under autograd. However, my understanding is that these decompositions are registered to the `"post_autograd"` decomposition table, meaning autograd wouldn't be a factor. Would like some confirmation that this premise is correct.
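A quick sanity check of the new table entries (a sketch; `core_aten_decompositions` is assumed to be the right helper here, and the overload names are assumptions):
```python
import torch
from torch._decomp import core_aten_decompositions

table = core_aten_decompositions()
for op in (
    torch.ops.aten.lift_fresh.default,
    torch.ops.aten.split.Tensor,
    torch.ops.aten.unbind.int,
):
    print(op, op in table)  # expect True for each after this change
```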
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110102
Approved by: https://github.com/jansel
Fixes https://github.com/pytorch/pytorch/pull/102577#issuecomment-1650905536
Serializing to JSON is more stable, and the API has been renamed:
```
# Takes in a treespec and returns the serialized treespec as a string. Also optionally takes in a protocol version number.
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
# Takes in a serialized treespec and outputs a TreeSpec
def treespec_loads(data: str) -> TreeSpec:
```
If users want to register their own serialization format for a given pytree, they can go through the `_register_treespec_serializer` API which optionally takes in a `getstate` and `setstate` function.
```
_register_treespec_serializer(type_, *, getstate, setstate)
# Takes in the context, and outputs a json-dumpable context
def getstate(context: Context) -> DumpableContext:
# Takes in a json-dumpable context, and reconstructs the original context
def setstate(dumpable_context: DumpableContext) -> Context:
```
We will serialize to the following dataclass, and then `json.dump` it to a string.
```
class TreeSpec:
    type: Optional[str]    # a string name of the type; None for a LeafSpec
    context: Optional[Any]  # optional, a json-dumpable format of the context
    children_specs: List["TreeSpec"]
```
If no getstate/setstate function is registered, we will by default serialize the context using `json.dumps/loads`. We will also serialize the type through `f"{typ.__module__}.{typ.__name__}"`.
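A hedged round-trip example of the renamed API (assuming the functions live in `torch.utils._pytree`):
```python
import torch.utils._pytree as pytree

_, spec = pytree.tree_flatten({"a": [1, 2], "b": (3,)})

serialized = pytree.treespec_dumps(spec)      # JSON string
restored = pytree.treespec_loads(serialized)  # back to a TreeSpec
assert restored == spec
```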
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106116
Approved by: https://github.com/zou3519
Summary: Previously, serializing graphs using map would error
because map returns a singleton tensor list rather than a
single tensor. So this diff adds support for higher order operators
that return a list of tensors as output.
We also run into an issue with roundtripping the source_fn on
map nodes/subgraphs. The source_fn originally is
<functorch.experimental._map.MapWrapper object at 0x7f80a0549930>, which
serializes to `functorch.experimental._map.map`. However, we are unable
to construct the function from this string. This should be fixed once
map becomes a fully supported operator like
torch.ops.higher_order.cond.
Differential Revision: [D48631302](https://our.internmc.facebook.com/intern/diff/D48631302)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107837
Approved by: https://github.com/zhxchen17
ghstack dependencies: #107818
Some Nvidia TensorRT folks were asking for a way to integrate the serialization of custom objects with export's serialization. After some discussion (more background [here](https://docs.google.com/document/d/1lJfxakmgeoEt50inWZ53MdUtOSa_0ihwCuPy_Ak--wc/edit)), we settled on a way for users to register their custom object's serializer/deserializer functions.
Since TorchScript's `.def_pickle` already exists for [registering custom classes](https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html), and `tensorrt.ICudaEngine` already contains a `.def_pickle` implementation, we'll start off by reusing the existing framework and integrating it with export's serialization.
TorchScript's `.def_pickle` requires users to register two functions, which end up being the `__getstate__` and `__setstate__` methods on the class. The semantics of `__getstate__` and `__setstate__` in TorchScript are equivalent to that of Python pickle modules. This is then registered using pybind's `py::pickle` function [here](https://www.internalfb.com/code/fbsource/[f44e048145e4697bccfaec300798fce7daefb858]/fbcode/caffe2/torch/csrc/jit/python/script_init.cpp?lines=861-916) to be used with Python's pickle to initialize a ScriptObject with the original class, and set the state back to what it used to be.
I attempted to call `__getstate__` and `__setstate__` directly, but I couldn't figure out how to initialize the object on which `__setstate__` would be called in Python. One option would be to create a `torch._C.ScriptObject` and then set the class and call `__setstate__`, but there is no constructor exposed for ScriptObjects. Another option would be to construct an instance of the serialized class itself, but if the class constructor required arguments, I wouldn't know what to initialize it with. In ScriptObject's `py::pickle` registration it directly creates the object [here](https://www.internalfb.com/code/fbsource/[f44e048145e4697bccfaec300798fce7daefb858]/fbcode/caffe2/torch/csrc/jit/python/script_init.cpp?lines=892-906), which is why I think just using Python's `pickle` directly will be OK, since this case is handled there.
So, what I did is that I check if the object is pickle-able, meaning it contains `__getstate__` and `__setstate__` methods, and if so, I serialize it with Python's pickle. TorchScript does have its own implementation of [pickle/unpickle](https://www.internalfb.com/code/fbsource/[59cbc569ccbcaae0db9ae100c96cf0bae701be9a][history]/fbcode/caffe2/torch/csrc/jit/serialization/pickle.h?lines=19%2C82), but it doesn't seem to have pybinded functions callable from python.
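A minimal sketch of that check, under the stated assumptions (the helper name is hypothetical; only the standard `pickle` module and attribute checks are used):
```python
import pickle

def serialize_custom_obj(obj) -> bytes:
    # Only objects whose class registered def_pickle (and therefore expose
    # __getstate__/__setstate__) get serialized with Python's pickle here.
    if not (hasattr(obj, "__getstate__") and hasattr(obj, "__setstate__")):
        raise RuntimeError(
            f"Cannot serialize {type(obj)}: it has no __getstate__/__setstate__ "
            "(register them via def_pickle on the custom class)."
        )
    return pickle.dumps(obj)
```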
One open question: is it OK to combine this pickle + JSON serialization?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107666
Approved by: https://github.com/gmagogsfm
Added the following APIs:
```
def save(
    ep: ExportedProgram,
    f: Union[str, pathlib.Path, io.BytesIO],
    extra_files: Optional[Dict[str, Any]] = None,
    opset_version: Optional[Dict[str, int]] = None,
) -> None:
    """
    Saves a version of the given exported program for use in a separate process.

    Args:
        ep (ExportedProgram): The exported program to save.
        f (str): A file-like object (has to implement write and flush)
            or a string containing a file name.
        extra_files (Optional[Dict[str, Any]]): Map from filename to contents
            which will be stored as part of f.
        opset_version (Optional[Dict[str, int]]): A map of opset names
            to the version of this opset.
    """

def load(
    f: Union[str, pathlib.Path, io.BytesIO],
    extra_files: Optional[Dict[str, Any]] = None,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
    """
    Loads an ExportedProgram previously saved with torch._export.save.

    Args:
        f (str): A file-like object (has to implement write and flush)
            or a string containing a file name.
        extra_files (Optional[Dict[str, Any]]): The extra filenames given in
            this map would be loaded and their content would be stored in the
            provided map.
        expected_opset_version (Optional[Dict[str, int]]): A map of opset names
            to expected opset versions.

    Returns:
        An ExportedProgram object.
    """
```
Example usage:
```
# With buffer
buffer = io.BytesIO()
torch._export.save(ep, buffer)
buffer.seek(0)
loaded_ep = torch._export.load(buffer)

# With file
with tempfile.NamedTemporaryFile() as f:
    torch._export.save(ep, f)
    f.seek(0)
    loaded_ep = torch._export.load(f)

# With Path
with TemporaryFileName() as fname:
    path = pathlib.Path(fname)
    torch._export.save(ep, path)
    loaded_ep = torch._export.load(path)

# Saving with extra files
buffer = io.BytesIO()
save_extra_files = {"extra.txt": "moo"}
torch._export.save(ep, buffer, save_extra_files)
buffer.seek(0)
load_extra_files = {"extra.txt": ""}
loaded_ep = torch._export.load(buffer, load_extra_files)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107309
Approved by: https://github.com/avikchaudhuri, https://github.com/gmagogsfm, https://github.com/tugsbayasgalan
Since constrain_as_size has been fixed, I tried serializing it, but ran into some issues.
Notably, after each `.transform` call, I added a helper `_get_updated_range_constraints` to update the range constraints list. This is because when we retrace in a pass, the symbolic values being used change, so we need to update this dictionary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107386
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
Added a version number to the schema for BC purposes. We will add this number to the serialized ExportedProgram, and when deserializing, if the number does not match the version the existing deserializer expects, we will error. We should update the number if there are any major changes to the schema.
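For illustration only, the load-time check could look roughly like this (the constant name and value are placeholders, not the actual schema code):
```python
SCHEMA_VERSION = 1  # placeholder for whatever version this deserializer targets

def check_schema_version(serialized_version: int) -> None:
    if serialized_version != SCHEMA_VERSION:
        raise RuntimeError(
            f"Serialized ExportedProgram has schema version {serialized_version}, "
            f"but this deserializer only understands version {SCHEMA_VERSION}."
        )
```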
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107420
Approved by: https://github.com/zhxchen17
Summary: Sometimes the graph being serialized contains nodes with side effects and no users (e.g. out variants of operators), so we don't want to eliminate those when deserializing.
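Roughly the behavior being preserved, as a hedged sketch (the helper and the side-effect set below are hypothetical, not the deserializer's actual code):
```python
import torch
import torch.fx as fx

SIDE_EFFECT_TARGETS = {torch.ops.aten.add.out}  # example out-variant op

def prune_dead_nodes(gm: fx.GraphModule) -> None:
    for node in reversed(list(gm.graph.nodes)):
        if node.users or node.op in ("placeholder", "output"):
            continue
        if node.op == "call_function" and node.target in SIDE_EFFECT_TARGETS:
            continue  # unused result, but the op has side effects: keep it
        gm.graph.erase_node(node)
    gm.recompile()
```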
Test Plan: CI
Differential Revision: D47735009
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105875
Approved by: https://github.com/ydwu4
Solving #105242.
During export, the exported function's signature changes multiple times. Suppose we'd like to export f as shown in the following example:
```python
def f(arg1, arg2, kw1, kw2):
    pass

args = (arg1, arg2)
kwargs = {"kw2": arg3, "kw1": arg4}

torch.export(f, args, kwargs)
```
The signature changes multiple times during the export process, in the following order:
1. **gm_torch_level = dynamo.export(f, *args, \*\*kwargs)**. In this step, we turn all kinds of parameters, such as **positional_only**, **var_positional**, **kw_only**, and **var_kwargs**, into **positional_or_kw**. It also preserves the positional and keyword argument names of the original function (i.e. f in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of kwargs will be the **key order** of the kwargs dict (after Python 3.6, that is the insertion order of its keys) rather than the order in the original function signature, and this order is baked into the _orig_args variable of gm_torch_level's pytree info. So we'll have:
```python
def gm_torch_level(arg1, arg2, kw2, kw1)
```
Such a difference is acceptable, as it's transparent to users of export.
2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we need to turn kwargs into positional args in the order gm_torch_level expects, which is stored in _orig_args. The returned gm_aot_export has the graph signature of flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args):
``` python
flat_args, _ = pytree.tree_flatten(pos_or_kw_args)
def gm_aot_export(*flat_args)
```
3. **exported_program(*args, \*\*kwargs)**. The exported artifact is exported_program, which is a wrapper over gm_aot_export and has the same calling convention as the original function f. To do this, we need to (1) specialize the order of kwargs into pos_or_kw_args and (2) flatten the pos_or_kw_args into what gm_aot_export expects. We can combine the two steps into one with:
```python
_, in_spec = pytree.tree_flatten((args, kwargs))
# Then during exported_program.__call__(*args, **kwargs)
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
```
Here kwargs is treated as a normal pytree whose key order is preserved in in_spec.
Implementation-wise, we treat _orig_args in the dynamo-exported graph module as the single source of truth, and kwargs are ordered following it.
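A small illustration of the key-order behavior this relies on (assuming the behavior of `torch.utils._pytree` and `torch.fx._pytree.tree_flatten_spec`):
```python
import torch.utils._pytree as pytree
from torch.fx import _pytree as fx_pytree

args = (1, 2)
kwargs = {"kw2": 30, "kw1": 40}

flat, in_spec = pytree.tree_flatten((args, kwargs))
print(flat)  # [1, 2, 30, 40]: kwargs values in insertion order (kw2 first)

# Flattening a later call's (args, kwargs) against the recorded spec reproduces
# the same order regardless of how the caller ordered the keys.
flat_again = fx_pytree.tree_flatten_spec((args, {"kw1": 40, "kw2": 30}), in_spec)
print(flat_again)  # [1, 2, 30, 40]
```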
Test plan:
See added tests in test_export.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
v2 of https://github.com/pytorch/pytorch/pull/102125 because of git issues
corresponding deserialization diff: https://github.com/pytorch/pytorch/pull/102716
Implementing serialization of the exported program to a Python dataclass, and then from that dataclass to JSON. This is split into a few pieces:
- `serialize(ep: ep.ExportedProgram, opset_version: Dict[str, int]) -> Tuple[bytes, bytes]` -- takes an exported program object and a dictionary mapping opset namespaces to versions, and returns the serialized exported program in bytes and, separately, the state dict serialized in bytes (see the sketch after this list)
- `GraphModuleSerializer` class that serializes torch.fx.GraphModule
to the schema.GraphModule dataclass
- `ExportedProgramSerializer` class that serializes torch._export.exported_program.ExportedProgram to the schema.ExportedProgram dataclass
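A hedged usage sketch of the `serialize` entry point described in the first bullet (the import path, file names, and opset value are assumptions, not part of this PR):
```python
# Import path assumed; in later releases the helpers live under torch._export.serde.
from torch._export.serde.serialize import serialize

serialized_program, serialized_state_dict = serialize(ep, opset_version={"aten": 1})

with open("exported_program.json", "wb") as f:
    f.write(serialized_program)
with open("state_dict.pt", "wb") as f:
    f.write(serialized_state_dict)
```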
Serialization TODOs:
- [x] pytree spec: https://github.com/pytorch/pytorch/pull/102577
- [ ] higher order ops
- [ ] node metadata (specifically nn_module_stack/source_fn)
- [ ] constraints
- [ ] graph module metadata
The tests are not super comprehensive, but that's because I think it'll be better and easier to test once deserialization is implemented.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102707
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17