TLDR; this PR supports exporting cond x inine_inbuilt nn modules flag by inling into tracing code in proxy_tensor.py _symbolic_trace.py (internally, the pattern is make_fx(record_module_stack)(torch.compile(f))).
We have two special treatments for following cases:
1. _ModuleStackTracer will wrap all the nn modules into _AttrProxy. This _AttrProxy has several subtiles which make it hard to inline in dynamo like overriding _modules with a property method and overrides the `__getattr__`, which mutates captured states when calling `__getattr__`.
Solution to this is that we unwrap the _AttrProxy and get its corresponding nn_module (a 1-1 correspondence). So that dynamo symbolically traces the original nn module instead of tracing _AttrProxy.
2. The tracer applies a bunch of patches the `__getattr__` and `__call__` of nn.Module for tracking reasons. This doesn't work well with dynamo. The immediate error we see is `torch._dynamo.exc.Unsupported: 'inline in skipfiles: WeakKeyDictionary.__contains__ | __contains__ /home/yidi/.conda/envs/pytorch/lib/python3.10/weakref.py` caused by a weakdict in PythonKeyTracer.
Solution to this is that we remove the patches during dynamo symbolic convert temporally. So that dynamo has a clean environment. make_fx will be trace the transformed bytecode of dynamo and patches nn modules there instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133731
Approved by: https://github.com/anijain2305
ghstack dependencies: #134775
This PR adds support for tracing `torch._C._pop_torch_function_stack()` without graph breaking and in order to verify the state change also adds replay of mutations to the torch function mode stack via side_effects appending supplemental bytecode as we do for other python mutable objects.
Details:
To represent the torch function mode stack symbolically a deque field is added to the instruction translator. When the InstructionTranslator is initialized, all modes are read from the current torch function mode stack, and stashed in a global weak ref for later access (using existing sources) without needing to push/pop the python/cpp torch function mode stack.
During tracing, when `_pop_torch_function_stack` is encountered a value is popped from this deque and the variable tracker representing the mode is returned. To ensure the true torch function mode stack matches this state, `TorchFunctionModeStackVariable`, a singleton, is marked as mutated, this adds it to side effects, where during final codegen, side effects will codegen a call to a python helper which will update the python torch function mode stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133131
Approved by: https://github.com/jansel
ghstack dependencies: #133130, #133729
All the changes brought by the original PR have been addressed in alternative ways in the stack. Why the original PR has to be reverted requires more effort because there is some bad interaction with export.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131058
Approved by: https://github.com/williamwen42
Fixes#129601
Background: it's possible that a traceable wrapper subclass will have an optional inner tensor constituent (e.g. NJT's cached min / max sequence lengths). To specify this, the subclass's `__tensor_flatten__()` impl should leave out any unspecified optional inner tensors in the returned list of `attrs`.
This PR guards on the list of inner tensor `attrs` returned in `subclass.__tensor_flatten__()[0]`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129618
Approved by: https://github.com/anijain2305
Significant bytecode generation API change!
The new suggested convention to generating bytecode to call a function is now to wrap instructions that push a callable to the stack with `add_push_null`, then that callable is called with `create_call_function` with `push_null=False` (see diff for examples).
In Python 3.13, NULL is now expected to be pushed after the callable. In <=3.12, the NULL was pushed before the callable. This change abstracts away the exact placement of the NULL, but the developer must be aware that a NULL may be needed when codegen'ing a callable.
This abstraction also reduces the need for the `push_null=True` option in `create_call_function`, which removes the need to rotate a NULL to the right place on the stack with a sequence of `SWAP` instructions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129172
Approved by: https://github.com/jansel
Fixes https://github.com/pytorch/pytorch/issues/125720
I was earlier worried that DELETE_* or STORE_* on referent values should result in a graph break, because they could invalidate the weak ref. But then @zou3519 pointed out that weakref invalidation will happen EVENTUALLY, CPython provides no guarantees when the weakref will be invalidated (even when the user calls del x and x is the last reference).
So any code that relies on del x to invalidate the weakref of x right away is BAD code. CPython provide no guarantees. Therefore we can (ab)use this nuance, and can just ignore DELETE_* or STORE_* on the referent objects.
The only corner case is when Dynamo is reconstructing the weakref object. Dynamo will have a hard time being correct here, so just SKIP_FRAME on such a case. This is rare.
Cpython notes
1) https://docs.python.org/3/library/weakref.html
2) https://docs.python.org/3/reference/datamodel.html#index-2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128533
Approved by: https://github.com/jansel
This PR requires a little justification, but let's start with what it does first:
1. When you have a 0d CPU scalar int64/float64 tensor input to a graph, we will preallocate a backed SymInt/SymFloat corresponding to what you would get if you call item() on this tensor. This means you can freely change your input to be a Python int/float or a Tensor with an item() call and end up with exactly the same level of expressivity (specifically, you can guard on the internal SymInt/SymFloat no matter what). By default, the source of the backed SymInt/SymFloat is `L['tensor'].item()`, but if you have promoted a float input into a Tensor, we will cancel out `torch.as_tensor(L['float']).item()` into just `L['float']`.
2. We switch wrap_symfloat to use this, instead of hand crafting the new SymNodeVariable. Everything works out, except that we carefully pass the item() result to tracked fakes (and not the fake Tensor argument)
OK, so why do this at all? There is some marginal benefit where now some item() calls on scalar inputs can be guarded on, but IMO this is a pretty marginal benefit, and if it was the only reason, I wouldn't do this. The real reason for this is that I need to be able to propagate fake tensors through the graphs that are produced by Dynamo, and if I am doing the old custom wrap_symfloat logic, there's no way I can do this, because ordinarily an item() call will cause an unbacked SymInt when I reallocate.
The other obvious way to solve the problem above is to make a HOP alternative that item() that "bakes in" the backed SymInt its supposed to return. But this strategy seems more parsimonious, and it does have the marginal benefit I mentioned above. The main downside is that what I have to do next, is make it so that when I run tensor computation, I also apply the equivalent operations to the SymInt/SymFloat as well. That's next PR.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126245
Approved by: https://github.com/eellison
ghstack dependencies: #126637
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by `specialize_float` config variable when set to False.
The generated graph looks like this for the test `test_unspec_float_output`:
```
def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
l_x_ = L_x_
l_y_ = L_y_
# File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
add: "f32[3]" = l_x_ + 1; l_x_ = None
item: "Sym(zf0)" = l_y_.item(); l_y_ = None
mul: "Sym(2*zf0)" = item * 2; item = None
scalar_tensor: "f32[]" = torch.scalar_tensor(mul); mul = None
return (add, scalar_tensor)
```
The ingredients:
* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up into the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we will do the rest of the tracing with. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we on the fly convert it (via `as_tensor`) to a TensorVariable, which is the true output. We then special case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
I am ok if people don't want this PR to be merged.
For optimizers, we know that the state dict and param_group have same parameters. So, I think its ok to skip TENSOR_MUST_ALIAS guards.
Similarly for state tensors, all of them are different. Therefore, we can skip the tensor aliasing guards.
With this PR, these are the numbers for Megatron which has 394 parameters
<img width="290" alt="image" src="https://github.com/pytorch/pytorch/assets/13822661/0ce75dc6-4299-46bb-bf3c-7989ebc7cfc4">
C++ numbers jump a lot because of 2 reasons
1) We are now not doing INCREF/DECREF for a large number of tensors.
2) For python guards, we can expect higher numbers but that requires some more plumbing because the Python tensor guards are all collapsed into one.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123044
Approved by: https://github.com/jansel, https://github.com/mlazos
Context: view fake-ification should handle closed-over state in ViewFuncs for use in view replay by:
* fake-ifying tensors
* symbolicizing SymInts
This avoids invalid specialization during view replay. However, the symbols / tensors created as intermediates in the view chain should not stick around or be guarded on. This PR introduces an `EphemeralSource` intended to be used as a source for this purpose. It has the following properties:
* Considered first to be simplified out in symbol simplification logic
* Errors if guarded on
Differential Revision: [D54561597](https://our.internmc.facebook.com/intern/diff/D54561597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120948
Approved by: https://github.com/ezyang
This adds support for backwards hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)
We do this by creating a BackwardState object that is used to register the hooks in the forward, then populated by dynamo *after* the forwards runs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
When building guards which went through a property we were analyzing the property using getattr_static but the guard wasn't built using getattr_static so if the property was "unusual" it generated misbehaved code which referenced a non-existent `__closure__` field.
Fixes#118786
Note that after this change some of the referenced tests are still failing with a different error - but getting further.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119719
Approved by: https://github.com/oulgen