The main thrust of the initial effort here was to capture `register_hook` calls on tensors in compile regions. The first part of this was done in https://github.com/pytorch/pytorch/pull/108903 wherein we added support for register_hook input tensors.
The distinction between input and intermediary is due to implementation differences.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
**The plan:**
For tensors w/ a source: (The PR above)
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2 modified bytecode with the original eager code, we call register_hook. This registration of hooks in residuals is sound because (a) it happens right after a Pt2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the users invoking frame. This means we can soundly know it will be around to invoke register_hook on. As long as we guard on the identity of the lifted function, this is sound to do.
For tensors w/o a source: (This PR)
Ostensibly, the most correct and complete solution would be to smuggle hooks into a runtime wrapper in aot_autograd, where all the items the hooks close over are lifted to inputs as necessary and passed alongside the user provided function. This is necessary so that we can properly trace out and capture all the mutations within the user defined hook at backwards time.
This is too complicated - so, we limited the scope of this initial PR to a simple subset of hooks:
- Hooks must have a source (be known to us already, not a lambda or intermediary defined function)
- We must be tracing under compiled autograd
**The flow**:
We use the HOP added in https://github.com/pytorch/pytorch/pull/109690/files, referred to as the HOP below.
1) We intercept register_hook calls and wrap the user defined fn in the HOP
2) We write a `_register_hook_trampoline` to the graph that is a local no-arg function that is invoked as a call_function in the dynamo graph
3) aot_autograd inlines through it during its trace, and sees the HOP
4) the HOP preserves itself in the graph - it does not get traced into
5) During backwards, compiled_autograd installs the HOP under a hook call
6) When compiled_autograd enters compilation over its generated graph, dynamo traces the contents of the hook
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109537
Approved by: https://github.com/ezyang
The strategy in this PR is pretty straightforward.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
The plan:
**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2 modified bytecode with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a Pt2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the users invoking frame. This means we can soundly know it will be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
**For tensors w/o a source:**
Graph break - we will support this in a subsequent PR
**Handles:**
An interesting new component here is the creation of a `STORE_FAST `->`LOAD_FAST` associated with the handle, the return result of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by interceding into `STORE_FAST`, and recording the name of the local variable as directed by user code. We then honor that same name in the reconstructed bytecode. If the user did not store a hook, we merely pop the produced value to preserve the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092
The strategy for supporting functools partials is relatively straightforward.
There are 2 cases we need to support:
**1) Functools partials as input**
In this case, we are first seeing the functools partial and it is guaranteed to have a source. As such, the args, keywords, and func of the functools partial are passed through VariableBuilder. As this is the first time we are seeing these objects (as it is an input), we re-enter VariableBuilder with a source referencing the args, keywords, and func as attributes of the input to produce:
- func: A callable VariableTracker (UDF, TorchVariable, etc) depending on the value of `func`
- args: List[VariableTracker] - note, not ListVariableTracker!
- keywords: Dict[str, VariableTracker]
A major benefit of this structure is that it very elegantly matches the args to `call_function`.
We then compose a FunctoolsPartialVariable from the VariableTrackers made above.
**2) Functools partials created within compile**
In this case, we already have all the args as known VTs, and thus just compose a FunctoolsPartialVariable as we do for case (1).
For both (1) and (2) - we propagate all guards from the func, args, and keyword VTs to the FunctoolsPartialVariable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108846
Approved by: https://github.com/ezyang, https://github.com/jansel
When inlining a function which loads a closure, its direct parent may not load that closure. So we cannot find the closure name in parent's symbolic locals. In this PR, we fix it by recursively searching the parent instruction translator stack to resolve the closure.
**Background**
When developing https://github.com/pytorch/pytorch/pull/105679, this corner case is triggered. A small repro is added in the test of this pr, where outer is loaded by deep2 but not by deep.
```python
def test_inline_closure_not_loaded_by_parent(self):
def outer(a):
return a + 1
def indirect(x):
return direct(x)
def direct(x):
def deep2(c):
return outer(c)
def deep(c):
return deep2(c)
return deep(x)
x = torch.randn(3)
eager = indirect(x)
counter = CompileCounter()
compiled = torch._dynamo.optimize(counter)(indirect)(x)
```
Running the test, we have the following error before the PR:
```
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6584, in test_inline_closure_not_loaded_by_parent
compiled = torch._dynamo.optimize(counter)(indirect)(x)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 321, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 481, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 130, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 362, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 194, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 531, in _compile
raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in _compile
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 417, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2067, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2227, in inline_call_
sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 471, in bind_args
result[name] = parent.symbolic_locals[name]
torch._dynamo.exc.InternalTorchDynamoError: outer
from user code:
File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6570, in indirect
return direct(x)
File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6579, in direct
return deep(x)
File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6577, in deep
return deep2(c)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
To execute this test, run the following from the base repo dir:
python test/dynamo/test_misc.py -k test_inline_closure_not_loaded_by_parent
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
---------------------------------------------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------------------------------------------
frames [('total', 1)]
inline_call []
---------------------------------------------------------------------------------------------------------------------------- Captured stderr call -----------------------------------------------------------------------------------------------------------------------------
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping helper /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /home/yidi/local/pytorch/torch/_dynamo/eval_frame.py
[2023-08-02 15:48:36,561] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6569
TRACE starts_line indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6569
def indirect(x):
[2023-08-02 15:48:36,591] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['x'] (3,) [<DimDynamic.STATIC: 2>] [None]
TRACE starts_line indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6570
return direct(x)
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF direct []
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [UserFunctionVariable()]
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [UserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object direct at 0x7fbe4d366810, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6572>
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6572 (inline depth: 1)
def direct(x):
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6573 (inline depth: 1)
def deep2(c):
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CLOSURE outer []
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 1 [InlinedClosureVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST <code object deep2 at 0x7fbe4d3666b0, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6573> [TupleVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST MiscTests.test_inline_closure_not_loaded_by_parent.<locals>.direct.<locals>.deep2 [TupleVariable(), ConstantVariable(code)]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE MAKE_FUNCTION 8 [TupleVariable(), ConstantVariable(code), ConstantVariable(str)]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_DEREF deep2 [NestedUserFunctionVariable()]
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6576 (inline depth: 1)
def deep(c):
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CLOSURE deep2 []
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 1 [NewCellVariable()]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576> [TupleVariable()]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST MiscTests.test_inline_closure_not_loaded_by_parent.<locals>.direct.<locals>.deep [TupleVariable(), ConstantVariable(code)]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE MAKE_FUNCTION 8 [TupleVariable(), ConstantVariable(code), ConstantVariable(str)]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST deep [NestedUserFunctionVariable()]
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6579 (inline depth: 1)
return deep(x)
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST deep []
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [NestedUserFunctionVariable()]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [NestedUserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576>
TRACE starts_line deep /home/yidi/local/pytorch/test/dynamo/test_misc.py:6576 (inline depth: 2)
def deep(c):
TRACE starts_line deep /home/yidi/local/pytorch/test/dynamo/test_misc.py:6577 (inline depth: 2)
return deep2(c)
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF deep2 []
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST c [NestedUserFunctionVariable()]
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [NestedUserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576>
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object direct at 0x7fbe4d366810, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6572>
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
```
Test Plan:
add new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106491
Approved by: https://github.com/williamwen42, https://github.com/jansel, https://github.com/zou3519
The main complexity comes from the __init__ function of Dataclass variables which look something like this
```
[2023-07-10 05:01:29,548] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object __init__ at 0x7f7015154450, file "<string>", line 2>
3 0 LOAD_FAST 1 (b)
2 LOAD_FAST 0 (self)
4 STORE_ATTR 0 (b)
4 6 LOAD_FAST 2 (named_tensors)
8 LOAD_DEREF 0 (_HAS_DEFAULT_FACTORY)
10 IS_OP 0
12 POP_JUMP_IF_FALSE 20
14 LOAD_DEREF 1 (_dflt_named_tensors)
16 CALL_FUNCTION 0
18 JUMP_FORWARD 2 (to 22)
>> 20 LOAD_FAST 2 (named_tensors)
>> 22 LOAD_FAST 0 (self)
24 STORE_ATTR 1 (named_tensors)
26 LOAD_CONST 0 (None)
28 RETURN_VALUE
```
There are multiple issues
* VariableBuilder call in functions.py was wrong. We were calling *options as args.
* We were not setting source while tracking the new object. This led to no source for Dataclass variable, which has some new variables in its closures as seen in the above bytecode.
* There is IS_OP in above bytecode, which brings more cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104840
Approved by: https://github.com/jansel
Fix https://github.com/pytorch/pytorch/issues/99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`.
My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in https://github.com/pytorch/pytorch/pull/101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells.
Adding this test causes an error though (EDIT: this test is not relevant to this PR and instead just reveals that `cond` with Python side effects is still broken):
```python
def test_closure_out_of_scope_cell_with_cond(self):
from functorch.experimental.control_flow import cond
cell1 = torch.rand(3, 3)
cell2 = torch.rand(3, 3)
orig3 = torch.rand(3, 3)
def test(x):
cell3 = orig3.clone()
def then():
nonlocal cell3
cell3 += cell1
return cell3
def els():
nonlocal cell3
cell3 += cell2
return cell3
return cond(x > 0, then, els, [])
opt_fn = torch._dynamo.optimize("eager")(test)
result1 = opt_fn(1)
self.assertTrue(torch.allclose(result1, orig3 + cell1))
result2 = opt_fn(-1)
self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2))
```
```
Traceback (most recent call last):
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond
result1 = opt_fn(1)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn
return fn(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn
return fn(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
return _compile(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile
out_code = transform_code_object(code, transform)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform
tracer.run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run
super().run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
and self.step()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
getattr(self, inst.opname)(inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper
return inner_fn(self, inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function
(false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch
ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph
output = f.call_function(tx, args, {})
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function
return tx.inline_user_function_return(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_
tracer.run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
and self.step()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
getattr(self, inst.opname)(inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function
proxy = tx.output.create_proxy(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy
return self.current_tracer.create_proxy(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy
new_arg = self.lift_tracked_freevar_to_input(arg)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input
self.parent.lift_tracked_freevar_to_input(proxy)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input
assert (
AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer
from user code:
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test
return cond(x > 0, then, els, [])
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els
cell3 += cell2
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104222
Approved by: https://github.com/jansel
Fixes#99665
Let me explain the root cause using the unit test I added:
* This bug is triggered when:
* ```wrapped``` is a nested function.
* ```wrapped``` is in another module which is different from the main function ```fn```.
* There is a graph break inside of ```wrapped```.
* The root cause is when resuming nested function, actually we are using the outermost function(```fn``` in my example)'s global variables, but ```wrapped``` calls ```inner_func``` which is not part of ```fn```'s globals, so we have to set correct globals when nested function resume execution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100426
Approved by: https://github.com/jansel
Fixes#99665
Let me explain the root cause using the unit test I added:
* This bug is triggered when:
* ```wrapped``` is a nested function.
* ```wrapped``` is in another module which is different from the main function ```fn```.
* There is a graph break inside of ```wrapped```.
* The root cause is when resuming nested function, actually we are using the outermost function(```fn``` in my example)'s global variables, but ```wrapped``` calls ```inner_func``` which is not part of ```fn```'s globals, so we have to set correct globals when nested function resume execution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100426
Approved by: https://github.com/jansel
@wconstab As we discussed last Friday, I added the unit test for explicitly calling __call__ and added comment to explain why we redirecting ```UserMethodVariable.call_function``` to ```NNModuleVariable.call_method``` for a certain case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100146
Approved by: https://github.com/wconstab
It's part of the effort to improve PT2 Export UX. This PR is to improve the usability of `torch.cond()` by separating user errors from the dynamo internal errors. By definition, user error means the usage of `torch.cond()` violates the restrictions of this API therefore needs users to take action and fix the error.
In this notebook N3363227 we discovered a bunch of limitations of using `torch.cond(pred, true_fn, false_fn, operands)`. In summary, the limitations can be categorized as:
- predicate restriction (`pred`)
- operands restriction (`operands`)
- branch restriction (`true_fn` & `false_fn`)
The error message will be more accurate about where the (user) error is from and more actionable for users to fix it.
For example, `operands` must be a list of tensors and the signature of `true_fn` and `false_fn` must match with the `operands`.
If the operands contains non-tensor types, user will see error message like:
```
torch._dynamo.exc.UserError: Expected a list of tensors but got ["<class 'torch.Tensor'>", "<class 'float'>"]
from user code:
File "~/pytorch/test/dynamo/test_export.py", line 2504, in f_non_tensor_operands
return cond(True, lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a])
```
If the signature of the branch function doesn't match with `operands`, user will see error message like:
```
torch._dynamo.exc.UserError: too many positional arguments.
func = 'false_fn' ~/pytorch/test/dynamo/test_export.py:2514, args = [<class 'torch.Tensor'>, <class 'torch.Tensor'>], kwargs = {}
```
Or if the tensor returned from user defined branches has different metadata, e.g. shapes, dtypes, etc., user will see error message like:
```
TypeError: Expected each tensor to have same metadata but got:
cond_true_0 returns TensorMetadata(shape=torch.Size([2, 1]), dtype=torch.int64, requires_grad=False, stride=(1, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
cond_false_0 returns TensorMetadata(shape=torch.Size([1]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98909
Approved by: https://github.com/jansel
Summary of changes:
- Add CPython exceptiontable parsing/assembling functions in torch/_dynamo/bytecode_transformation.py, based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt.
- Add optional `exn_tab_entry` field to dynamo `Instruction`s in torch/_dynamo/bytecode_transformation.py in order to virtualize exception table entries (start, end, target instructions).
- Add checks guarding against duplicate instructions in dynamo, so that jump/exceptiontable targets are unambiguous. See `get_indexof` in torch/_dynamo/bytecode_analysis.py. Ensure that bytecode generation throughout dynamo does not generate duplicate instructions.
- Allow dynamo bytecode generation logic to generate nested exception table entries for developer convenience. CPython expects entries to not overlap, so we flatten nested entries during assembly in torch/_dynamo/bytecode_transformation.py:compute_exception_table.
- Simulate the block stack in torch/_dynamo/symbolic_convert.py. CPython removed the block stack in 3.11, but dynamo needs it in order to keep track of active contexts. So we simulate the block stack as before by looking at exceptiontable entries in order to determine the current blocks.
- Update context codegen in torch/_dynamo/resume_execution.py. The `SETUP_FINALLY` bytecode, which conveniently had a jump target to the finally block, was removed in 3.11, so we need to keep track of the jump target of the finally block using exceptiontables. Generating resume functions is more difficult since the original exceptiontable entries pointing to old cleanup code need to be modified to point to new cleanup code.
- Fix a push_null bug in torch/_dynamo/variables/functions.py introduced by https://github.com/pytorch/pytorch/pull/98699
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96511
Approved by: https://github.com/jansel, https://github.com/yanboliang, https://github.com/albanD
Enable some sensible flake8-simplify rules. Mainly wanted to enable the SIM101, and `yield from` SIM103 checks. @kit1980 since you wanted to be tagged on this CI check.
Enabling this check also helped flag one logical bug so it's definitely beneficial (also fixed in this PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97984
Approved by: https://github.com/ezyang
Handle tensor default func/method args when inlining
Previously, when inlining a function, its default arguments
were only wrapped with VariableTrackers if non-tensor. Now,
tensor default args are also handled by adding them to the
parent InstructionTranslator as an attribute.
- also patches up a missing source in nnmodule call_function,
needed to properly guard on a default arg in its methods
- adds new 'DefaultsSource' type which guards either a `__defaults__`
or `__kwdefaults__` entry on a function
Fixes#90361https://github.com/pytorch/torchdynamo/issues/1968
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90575
Approved by: https://github.com/voznesenskym
**Motivation**
When adding support for default args (#90575), a lot of VariableTrackers missing sources were encountered. Currently, in a lot of cases it seems OK to skip the source for VariableTrackers created (especially during inlining), but that assumption breaks down when inlining functions with default arguments.
**Summary** of changes
- propagate the self.source of the VariableBuilder to the new variables being built, which seems like it was an omission previously
- Add SuperSource to track usages of super(), so that SuperVariables can support function calls with default args
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91729
Approved by: https://github.com/ezyang
I'm going to need this in the follow up PR. Instead of storing only Source.name() in Symbol, I now store a full on Source. Lots of replumbing reoccurs. In particular:
- Move Source to torch._guards to break cycles
- I have to add TensorPropertySource and NegateSource to handle x.size()[0] and -x codegen that I was doing with string manipulation previously
- I tighten up invariants so that I never pass source=None; instead I pass ConstantSource (these are constant sources right) and test for that rather than source being missing. I think this is more parsimonious
- Some mypy wobbles from new imports
I didn't move LocalSource and friends to torch._guards, but I ended up needing to access them in a few places. The main annoyance with moving these is that then I also need to move the bytecode codegen stuff, and that's not so easy to move without bringing in the kitchen sink.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91057
Approved by: https://github.com/albanD, https://github.com/voznesenskym, https://github.com/zou3519
I'm going to need this in the follow up PR. Instead of storing only Source.name() in Symbol, I now store a full on Source. Lots of replumbing reoccurs. In particular:
- Move Source to torch._guards to break cycles
- I have to add TensorPropertySource and NegateSource to handle x.size()[0] and -x codegen that I was doing with string manipulation previously
- I tighten up invariants so that I never pass source=None; instead I pass ConstantSource (these are constant sources right) and test for that rather than source being missing. I think this is more parsimonious
- Some mypy wobbles from new imports
I didn't move LocalSource and friends to torch._guards, but I ended up needing to access them in a few places. The main annoyance with moving these is that then I also need to move the bytecode codegen stuff, and that's not so easy to move without bringing in the kitchen sink.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91057
Approved by: https://github.com/albanD, https://github.com/voznesenskym
Wow, I had to sweat so much to get this PR out lol.
This PR enforces the invariant that whenever we allocate SymInts as part of fakeification, the SymInt is associated with a Source, and in fact we store the string source name on SymbolWithSourceName. We use 'sname' as the shorthand for source name, as 'name' is already used by sympy to name symbols.
In order to store source names, we have to plumb source names from Dynamo to PyTorch. This made doing this PR a bit bone crushing, because there are many points in the Dynamo codebase where we are improperly converting intermediate tensors into fake tensors, where there is no source (and there cannot be, because it's a frickin' intermediate tensor). I've fixed all of the really awful cases in earlier PRs in the stack. This PR is just plumbing in source names from places where we do have it.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90295
Approved by: https://github.com/voznesenskym
The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph. This is problematic for a number of reasons:
* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at https://github.com/pytorch/pytorch/pull/90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond https://github.com/pytorch/pytorch/pull/87020#issuecomment-1338842844 which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)
This PR replaces the old implementation with a new implementation based on graphstate checkpointing. The basic idea is to process a cond(), we checkpoint the state of our interpreter, run the true branch, rollback to our checkpoint, run the false branch, rollback to our checkpoint and then merge the changes from both of the checkpoints. I require the true/false branches to have exactly the same side effects, but union their guards.
Some of the details:
* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0`, `name_1`, now they are named `cond_true_0` and `cond_false_0`
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculating execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate as you have to make sure you reset the Graph to its original before you restore a checkpoint, as checkpoints don't actually save graph for efficiency, and just undo changes on the graph. This capability may usefully get refactored to OutputGraph but I didn't do it in this PR for simplicity.
There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.
* Not a problem per se, but if an NN module is used by both the true/false branch, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* List of tensor output for cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc, and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, c.f. https://github.com/pytorch/torchdynamo/issues/1840
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90286
Approved by: https://github.com/voznesenskym