Constant time access of first value in collection. This is a constant time operation instead of converting the item to a list to get the first item which is linear. The rule is turned on which automatically autofixes and enforces this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115507
Approved by: https://github.com/malfet
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Summary:
Fixed nn_module_stack dynamo produced by symbolic trace to align with the nn_module_stack metadata produced by dynamo. The key should be the module path, with the value being a unique name, and the type. Something like: `{'L__self___one_module': ("L['self'].one_module", <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>)}`
This was causing some tests to fail when using export + the old quantization flow (prepare_fx calls symbolic_trace).
Test Plan: D51534471 `buck2 run @//mode/dev-nosan //executorch/backends/xnnpack/test:test_xnnpack_quantized -- -r "test_xnnpack_leaky_relu"`
Differential Revision: D51539118
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114422
Approved by: https://github.com/JacobSzwejbka, https://github.com/jerryzh168
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Summary:
Traditionally when user want to update the arguments for an FX node, the only way is to call the setter of .args property on nodes. This may be problematic when we insert a lot of arguments. Because of the semantics of the setter method, it has a worst case O(n) complexity.
Adding a new insert_arg provides us two benefits:
1. The operation is guaranteed to be O(1) cost.
2. User can express the intentation more directly, instead of writing code like `node.args = (arg,) + node.args`
Test Plan: caffe2/test:fx -- -r test_insert_arg
Reviewed By: suo
Differential Revision: D50574435
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111974
Approved by: https://github.com/angelayi
Did some easy fixes from enabling TRY200. Most of these seem like oversights instead of intentional. The proper way to silence intentional errors is with `from None` to note that you thought about whether it should contain the cause and decided against it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111496
Approved by: https://github.com/malfet
A resubmit of https://github.com/pytorch/pytorch/pull/108447. Copy over the descriptions:
This is a follow-up of the discussion in https://github.com/pytorch/pytorch/pull/108356, where we want to repalce source_fn with source_fn_stack
Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()
@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
def true_fn(pred2, x, y):
return x + y
def false_fn(pred2, x, y):
def true_fn2(x, y):
return x.sin() - y.cos()
def false_fn2(x, y):
return x.cos() - y.sin()
return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))
return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
l_pred_ = L_pred_
l_pred2_ = L_pred2_
l_x_ = L_x_
l_y_ = L_y_
cond_true_1 = self.cond_true_1
cond_false_1 = self.cond_false_1
cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]); l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
return (cond,)
class GraphModule(torch.nn.Module):
def forward(self, l_pred2_, l_x_, l_y_):
add = l_x_ + l_y_; l_x_ = l_y_ = None
return add
class GraphModule(torch.nn.Module):
def forward(self, l_pred2_, l_x_, l_y_):
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]); l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
return cond
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_):
sin = l_x_.sin(); l_x_ = None
cos = l_y_.cos(); l_y_ = None
sub = sin - cos; sin = cos = None
return sub
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_):
cos = l_x_.cos(); l_x_ = None
sin = l_y_.sin(); l_y_ = None
sub = cos - sin; cos = sin = None
return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```
After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```
Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108595
Approved by: https://github.com/angelayi, https://github.com/zou3519
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the added functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the added functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
Add similar semantics for creating a buffer object similar to creating a parameter. This is done by introducing a new `Buffer` class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same as the `register_buffer` method has not been changed. The `persistent` parameter in the `Buffer` type is to indicate whether a buffer object should be persistent or not. Other non-test changes have to do with getting the new `Buffer` type recognized by inductor and dynamo. Remaining changes are test changes to make sure that the `Buffer` type can be used as a drop in replacement for `register_buffer` as it just leads to `register_buffer` being called. The addition of this new functionality still allows for normal tensors to be used as buffers so these changes are intended to be backwards compatible.
Fixes#35735
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104069
Approved by: https://github.com/mikaylagawarecki
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`
I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
Summary:
When we pickle/unpickle graph module in multipy, we would lost modules/attributes that are not referred in the graph. This is because when unpickle fx graph module, we use the stored `__dict__` and the fx graph to create a new graph module. In GraphModule init, we drop any attribute that is not referred in the graph.
This behavior is not ideal because we actually expect a graph module that's exactly the same after unpickling.
Test Plan:
```
buck test mode/opt caffe2/test:fx -- test_preserve_unused_attr_after_unpickle
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```
Differential Revision: D46976230
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104115
Approved by: https://github.com/houseroad
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`
I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
Summary:
# Context
In TorchRec's train pipeline, we need to fx trace a module to analyze the arguments on the forward call. In order to do this, we need to preserve some sort of meaning with each argument (a key or name of sorts that lets us identify the argument).
The issue is, when you use concrete args, internally, fx will unflatten the arg into it's constituents (to locate PHs).
Given a function that looks like this:
```
def process(batch: Dict[str, torch.Tensor]):
....
symbolic_trace(process, concrete_args: {"batch": {"f1": PH, "f2": PH}})
# function will be rewritten to look like:
def process(batch_1, batch_2): # batch_1 -> "f1", batch_2->"f2"
...
```
When you traverse through the nodes of the graph, the names of the argument nodes to the function are batch_1 and batch_2. **This doesn't mean anything to the user who is fx tracing.** There isn't anything indicating that batch_1 corresponds to key "f1" in the batch input.
# Solution
When fx sees a "PH", it creates a proxy node.
The user does not have direct access to proxy creation, but only through the PH structure.
Attach a piece of metadata, `ph_key`, to the PH when you set it in the concrete args, it will get passed into proxy + node creation. So when you traverse the graph, this metadata sticks onto the node as an attribute. This way you have a way of tagging that "batch_1" as "f1".
Test Plan: added a unit test
Reviewed By: dstaay-fb
Differential Revision: D44947653
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102195
Approved by: https://github.com/PaliC
Summary: Change placeholder check from singleton to instanceof PHBase so you can create your own PH class with metadata
Test Plan: added unit test
Reviewed By: joshuadeng
Differential Revision: D46085128
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102008
Approved by: https://github.com/PaliC
Added helper functions to match nodes in the graph that are decomposed from their source (leaf modules, or functional ops), as a result of dynamo tracing.
`get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]`
Args:
* graph: The graph we want to partition
* wanted_sources: List of sources of nodes that were decomposed from this source. This can be a function (ex. torch.nn.functional.linear) or a leaf module type (ex. torch.nn.Linear)
Returns:
* Dictionary mapping sources (ex. torch.nn.modules.linear.Linear) to a list of SourcePartitions that correspond to the list of nodes that were flattened from a module of that type.
```
@dataclass
class SourcePartition():
# Nodes in a particular partition
nodes: List[Node]
# Module type
module_type: Type
# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the partition
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[str] = field(default_factory=list)
```
Example:
Original:
```
x -> linear -> linear -> relu -> linear
```
Traced graph:
```
.graph():
%arg0 : [#users=1] = placeholder[target=arg0]
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0,), kwargs = {})
%_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
%_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0_1,), kwargs = {})
%_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
%relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
%_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
%_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
%addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
return [addmm_default_2]
```
Result of `get_module_partitions`:
```
{<class 'torch.nn.modules.linear.Linear'>: [
ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],
<class 'torch.nn.modules.activation.ReLU'>: [
ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
```
Also added helper function to check if two module partitions are connected:
`check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98628
Approved by: https://github.com/cccclai
Summary: There're some customized functions that we would also like to keep during eliminate dead code pass. Add a function to help us to do.
Test Plan: Added a unit test
Differential Revision: D44273630
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97288
Approved by: https://github.com/houseroad
Summary:
Currently torch.fx support Modules with input of namedtuple/dataclass, return as namedtuple, but does not allow Module.forward to return a dataclass, running `test_trace_return_dataclass` without this change will have following error:
NotImplementedError: argument of type: <class 'test_fx.TestFX.test_trace_return_dataclass.<locals>.MyOutput'>
File "test_trace_return_dataclass
traced_graph = symbolic_trace(module).graph
File "test/__fx__/fx#link-tree/torch/fx/_symbolic_trace.py", line 1114, in symbolic_trace
graph = tracer.trace(root, concrete_args)
File "test/__fx__/fx#link-tree/torch/fx/_symbolic_trace.py", line 783, in trace
(self.create_arg(fn(*args)),),
File "test/__fx__/fx#link-tree/torch/fx/_symbolic_trace.py", line 378, in create_arg
return super().create_arg(a)
File "test/__fx__/fx#link-tree/torch/fx/proxy.py", line 269, in create_arg
raise NotImplementedError(f"argument of type: {type(a)}")
this diff handle dataclass type.
Test Plan:
buck test @//mode/opt @//mode/inplace //caffe2/test:fx -- test_trace_
graph():
%d : torch.Tensor [#users=1] = placeholder[target=d]
%my_output : [#users=1] = call_function[target=test_fx.MyOutput](args = (), kwargs = {foo: %d, bar: %d})
return my_output
Differential Revision: D44916519
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99576
Approved by: https://github.com/suo
Twice this week I have had people confuse "operator defined with Python
operator registration aka torch.library" and "PyOperator which is used
to define control flow operators and other operators that cannot be
represented in JIT schema." Renaming PyOperator for clarity.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97493
Approved by: https://github.com/SherlockNoMad
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
To match nodes within the graph, the matcher currently flattens the arguments and compares each argument against each other. However, if it believes that a list input contains all literals, it will not flatten the list and will instead compare the list directly against each other. It determines if a list is a literal by checking if the first element is a node. However this doesn't work in some cases (like the test cases I added).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94375
Approved by: https://github.com/SherlockNoMad