See the added test for the case this PR handles. In particular, the semantics of nested torch.compile with toggled fullgraph settings were strange before: `@torch.compile(fullgraph=True)` overrides the existing fullgraph setting, while `@torch.compile(fullgraph=False)` does not.
Note that this change will add an extra frame to any inlined torch.compile'd function (which I don't expect to happen frequently).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155166
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782
Let's first explore a couple of problems related to replacements and runtime assertions.
#### Example problem 1
Suppose we have a runtime assertion u0 == s0, where u0 is an input coming from mark_unbacked. A replacement u0 = s0 will be added, so the function f(u0, s0) becomes f(s0, s0), and the assertion never gets inserted during insert_deferred_runtime_asserts.
The reason is that insert_deferred_runtime_asserts emits each assertion once all of its input symbols have been seen, but u0 is never seen. The same thing can happen when we defer an assertion on backed symbols, e.g. s0 == s2.
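A minimal repro sketch of this scenario, assuming the `torch._dynamo.decorators.mark_unbacked` and `torch._check` flow (this is not the exact test added in the PR):
```python
import torch
from torch._dynamo.decorators import mark_unbacked

@torch.compile(backend="eager", dynamic=True)
def f(x, y):
    torch._check(x.size(0) == y.size(0))  # deferred runtime assert: u0 == s0
    return x + y

x = torch.randn(8)
mark_unbacked(x, 0)   # x.size(0) becomes an unbacked input symbol u0
y = torch.randn(8)    # y.size(0) becomes a backed symbol s0
f(x, y)
# If the shape env replaces u0 with s0, the assert's symbols are never all
# "seen" among the graph inputs, so the runtime assert is silently dropped.
```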
#### Example problem 2
Consider u0 == s0, where u0 comes from a call to .item(). Imagine that later on s0 specializes to 2. In that case s0 is never seen as an input during insert_deferred_runtime_asserts, so the assertion is not inserted into the graph. Worse, Inductor will generate code in the cpp wrapper that refers to s0 even though it no longer exists, causing a failure.
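A similarly hedged sketch of this scenario, where the unbacked symbol comes from `.item()` (assumes `capture_scalar_outputs`; the specialization of s0 is described in the comments, not forced by the code):
```python
import torch
import torch._dynamo.config

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(backend="eager", dynamic=True)
def f(n, y):
    u0 = n.item()                   # data-dependent unbacked symbol u0
    torch._check(u0 == y.size(0))   # deferred runtime assert: u0 == s0
    return y * u0

f(torch.tensor(2), torch.randn(2, 3))
# If s0 later specializes to 2, the assert refers to an input symbol that no
# longer exists, so it is dropped (or the cpp wrapper references a missing s0).
```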
internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1669766396994898/
## The solution
Runtime assertion insertion depends on detecting that the symbols used in a runtime assertion have been seen; those symbols are either graph inputs or generated in the graph by data-dependent ops like .item().
The issues above happen when the symbols are graph inputs. To force those symbols to exist in the graph and be seen by the runtime assertion insertion pass, we do not apply replacements to placeholder expressions during codegen or during runtime assertion insertion.
This should not add performance overhead: the graph has already been optimized with replacements, and the only effect is that we no longer mistakenly drop graph inputs that are used in runtime assertions.
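A minimal, self-contained sketch of that mechanism (illustrative only, not the actual insert_deferred_runtime_asserts code): emit an assert once all of its symbols are seen, iterating over the original placeholder symbols without applying replacements.
```python
import sympy

u0, s0 = sympy.symbols("u0 s0", integer=True)
replacements = {u0: s0}                  # the replacement that used to hide u0
placeholders = [u0, s0]                  # original graph input symbols
deferred_asserts = [sympy.Eq(u0, s0)]

seen, emitted = set(), []
for sym in placeholders:                 # note: replacements are NOT applied here
    seen.add(sym)
    for ra in deferred_asserts:
        if ra not in emitted and ra.free_symbols <= seen:
            emitted.append(ra)           # u0 == s0 is emitted once both are seen

print(emitted)   # [Eq(u0, s0)]
```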
I added extended testing. One unrelated follow-up I noticed: we might want to rename unbacked symbols in runtime assertions when we do unbacked renaming, but that's a separate issue.
Other approaches that did not work:
#### Ban replacements on unbacked symbols
1. It does not work when we defer runtime assertions on backed symbols, e.g. s0 == s1. We could also ban such replacements, but then problem 2 becomes even more problematic.
2. It degrades the quality of reasoning.
#### Apply specializations to runtime assertions before codegen
1. This can fix some issues, but it may also turn runtime assertions into no-ops.
2. It does not fix the case where a runtime assertion is never inserted during insert_deferred_runtime_asserts because an input is not detected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153661
Approved by: https://github.com/jansel
When a tensor has unbacked symbols, it can be general enough to represent both contiguous and non-contiguous tensors; in that case we can't really evaluate is_contiguous. In many places in the code base we check is_contiguous to take a fast path, but the general path usually works for both contiguous and non-contiguous tensors, so in those cases we probably want to use the definitely_contiguous API instead.
This is applied to reshape in this PR and also to tensor metadata computation: the metadata now has an attribute that says the tensor is contiguous only when it is always contiguous, i.e. we store it only if definitely_contiguous is true.
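A hedged sketch of the fast-path pattern described above; `definitely_contiguous` here is a plain Python stand-in for the helper named in the text:
```python
import torch

def definitely_contiguous(t: torch.Tensor) -> bool:
    # Stand-in: the real helper returns False (rather than raising a
    # data-dependent error) when contiguity cannot be proven.
    try:
        return t.is_contiguous()
    except Exception:
        return False

def reshape_fast_or_general(t: torch.Tensor, shape):
    if definitely_contiguous(t):
        return t.view(shape)     # fast path only when contiguity is certain
    return t.reshape(shape)      # general path works for both cases

print(reshape_fast_or_general(torch.randn(2, 3), (3, 2)).shape)
```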
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432
Approved by: https://github.com/bobrenjc93
PR time benchmarks have been showing regressions as we move to guard_or_false; the reason is that the previous implementation did not cache.
The new approach propagates the fallback value into eval and returns it, allowing eval to cache and reducing log spam and complexity.
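A self-contained sketch of why pushing the fallback into the evaluator enables caching (not the real ShapeEnv API; `eval` on a string stands in for symbolic evaluation):
```python
from functools import lru_cache

@lru_cache(maxsize=None)
def _eval_with_fallback(expr: str, fallback: bool) -> bool:
    try:
        return bool(eval(expr, {"__builtins__": {}}, {}))  # stand-in for symbolic eval
    except Exception:
        return fallback     # the fallback result is cached with the expression

def guard_or_false(expr: str) -> bool:
    return _eval_with_fallback(expr, False)

def guard_or_true(expr: str) -> bool:
    return _eval_with_fallback(expr, True)

print(guard_or_false("u0 == 2"))   # False; repeated queries hit the cache
```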
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153674
Approved by: https://github.com/bobrenjc93
#### Change 1: if compute_strides fails for reshape, just clone
Let's consider the most general case: if torch.compile is asked to reshape a tensor with sizes [u0, u1] and strides [u3, u4] to [u5, u6], what should it do?
The shape is general enough to represent both contiguous and non-contiguous tensors: tensors where a clone-free reshape can happen and others where it cannot. The current algorithm fails with data-dependent errors.
The general idea is: if it is impossible to tell whether the reshape can happen in place (because for some concrete inputs it can and for others it cannot), then it is OK to take the general path and clone, instead of failing or asking the user for hints.
**Because the user wants a single graph (a single compilation)**, this is the only way it can be done.
Had this been a view, the user would be explicitly asking for a copy-free reshape, and we would fail and ask for more information (hints in the form of torch._check calls).
With this change, reshape works as follows:
1. If we know the input is contiguous, we convert the reshape to a view.
2. If compute_strides succeeds, we use a view. (compute_strides was changed so it does not fail when unbacked symbols are present; instead it returns nullptr if it cannot compute the strides, meaning we should clone.)
3. If neither 1 nor 2 works, clone and then view; a sketch of this decision order follows the side note below.
Side note: having a view does not mean that Inductor will not clone; Inductor has a pass that converts all views back to reshapes, plus its own logic for dealing with those.
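A hedged sketch of the decision order; `compute_strides_or_none` is a hypothetical stand-in for ATen's stride computation returning nullptr:
```python
from typing import Optional, Sequence
import torch

def compute_strides_or_none(t: torch.Tensor, shape: Sequence[int]) -> Optional[tuple]:
    try:
        return t.view(shape).stride()   # placeholder for the real stride computation
    except RuntimeError:
        return None                     # "cannot do it copy-free" -> caller clones

def reshape_sketch(t: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
    if t.is_contiguous():                         # step 1: known contiguous -> view
        return t.view(shape)
    strides = compute_strides_or_none(t, shape)   # step 2: strides computable -> view
    if strides is not None:
        return t.as_strided(shape, strides)
    return t.contiguous().view(shape)             # step 3: clone, then view

print(reshape_sketch(torch.randn(4, 6).t(), (3, 8)).shape)
```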
#### Change 2: skip _reshape_view_helper and fall back to simpler logic if it fails
We trace _reshape_view_helper during fake tensor tracing but not during proxy tracing, so this tracing does not affect the graph (it only computes the output shapes of several operations). We should not fail there, because for a reshape it should always be possible to get through:
by the time reshape_symint is called, we have either cloned or compute_strides has succeeded, so the view should pass. What I did is the following: we run _reshape_view_helper, and if it fails due to unbacked symbols we call _view_simple, which always succeeds for reshapes (it might fail for views when the view is impossible; in that case we throw the data-dependent error thrown by the original algorithm).
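A self-contained sketch of that try/fallback control flow; the helper names mirror the ones above, but their bodies are illustrative only:
```python
import torch

class DataDependentError(RuntimeError):
    pass

def _reshape_view_helper(a, shape):
    # Stand-in: pretend the original algorithm hit a data-dependent error.
    raise DataDependentError("unbacked sizes: cannot decide")

def _view_simple(a, shape):
    # Always succeeds for reshape-style callers in this sketch.
    return a.reshape(shape)

def reshape_meta(a, shape):
    try:
        return _reshape_view_helper(a, shape)
    except DataDependentError:
        # For reshape we can always take the simple path; for a true view we
        # would re-raise the original data-dependent error if this also fails.
        return _view_simple(a, shape)

print(reshape_meta(torch.arange(6), (2, 3)).shape)   # torch.Size([2, 3])
```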
Ideally I would register _view_simple as the meta for view and avoid calling _reshape_view_helper entirely, but I am running into dispatcher issues with subclasses and do not have time to debug them. Namely, one test
ends up calling a C++ view function that does not support symints during meta dispatch when I register a
Python meta decomposition:
```python test/dynamo/test_subclasses.py SubclassTests.test_subclass_views_dynamic_True ```
https://github.com/pytorch/pytorch/issues/153303. I will follow up with that change in a separate PR. cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @bdhirsh
Two other alternatives to registering _view_simple as the meta and to the try/except approach in this PR:
1. Call _view_simple if any input is dynamic, see #153521.
2. If is_compiling worked for framework-code tracing (it does not right now), we could call _view_simple when is_compiling is true.
#### Note:
Reshape can still fail when is_contiguous is called; the next PR will handle that by calling is_known_contiguous.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153198
Approved by: https://github.com/etaf, https://github.com/bobrenjc93
A typical `bmm` kernel in Helion needs to pass symint shapes to `torch.baddbmm`. Currently, `self.expand((dim1, dim2, dim3))` in baddbmm runs unconditionally, and it doesn't work with symint shapes (it raises the following error):
```
Traceback (most recent call last):
File "/home/willfeng/local/helion_yf225/helion/_compiler/type_propagation.py", line 699, in propagate_call
CheckForIndexCalls.retry_call(self.value, proxy_args, proxy_kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion_yf225/helion/_compiler/tile_index_proxy.py", line 104, in retry_call
return fn(*proxy_args, **proxy_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch/torch/_subclasses/fake_tensor.py", line 1338, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch/torch/_subclasses/fake_tensor.py", line 1986, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch/torch/_subclasses/fake_tensor.py", line 1450, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch/torch/_subclasses/fake_tensor.py", line 2645, in _dispatch_impl
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch/torch/_ops.py", line 806, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch/torch/_prims_common/wrappers.py", line 309, in _fn
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch/torch/_meta_registrations.py", line 2172, in meta_baddbmm
self = self.expand((dim1, dim2, dim3))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /home/willfeng/local/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp:5025: SymIntArrayRef expected to contain only concrete integers
```
This PR changes it so that we don't run `expand()` when not necessary, which makes the Helion use case (i.e. no broadcasting) work.
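A hedged sketch of the kind of guard added around the expand; the exact condition in the real `meta_baddbmm` may differ:
```python
import torch

def maybe_expand_self(self, dim1, dim2, dim3):
    # Only materialize the broadcast when the shape actually differs; in the
    # no-broadcasting Helion case the expand() (and the SymIntArrayRef check
    # that raised above) is skipped entirely.
    if tuple(self.shape) != (dim1, dim2, dim3):
        self = self.expand((dim1, dim2, dim3))
    return self

print(maybe_expand_self(torch.randn(2, 3, 4), 2, 3, 4).shape)   # no expand needed
```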
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153112
Approved by: https://github.com/jansel
When we use mark_unbacked, the graph will have unbacked input symints. Right now,
deferred runtime assertions that use those symbols are never generated.
This PR changes that: in the forward graph we consider those symbols and generate the corresponding
runtime assertions. We still ignore them for backward, which is not ideal.
The way we generate runtime assertions is by emitting each one once all of the defined unbacked symbols it uses
have been seen.
We previously skipped placeholders because, for backward, we have a wacky approach where we
ignore input-defined unbacked symbols, assume that assertions using them were already emitted
in forward, and try to emit all other runtime assertions again; see Note [Backwards runtime asserts].
Doing that, we end up emitting only the runtime assertions that depend on things defined solely in backward, but we can miss checks that span inputs defined in both backward and forward (i.e. one symbol defined in forward and passed as an input to backward, and another defined in backward). This is not ideal; a better approach could be something like https://github.com/pytorch/pytorch/pull/151919, but it requires more work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152231
Approved by: https://github.com/aorenste
Chatting with Bob, the goal of this is to const-fold the floats that were tensorified, by calling
guard_scalar(val) on them and then replacing their usages with their values.
Hence we do not need to do this for nodes with no float symbols.
We do not want to do full const folding, because we need to preserve statements that deferred
runtime asserts depend on (see the added test).
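A minimal sketch of the "skip nodes with no float symbols" check (symbol names are illustrative; the real pass walks FX nodes):
```python
import sympy

def has_float_symbol(expr: sympy.Expr) -> bool:
    return any(not s.is_integer for s in expr.free_symbols)

zf0 = sympy.Symbol("zf0", real=True)     # a tensorified float
s0 = sympy.Symbol("s0", integer=True)    # an integer size symbol

print(has_float_symbol(zf0 * 2 + 1))     # True  -> candidate for const folding
print(has_float_symbol(s0 + 1))          # False -> skip the node
```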
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151494
Approved by: https://github.com/bobrenjc93
PR https://github.com/pytorch/pytorch/pull/149665 made a change to optimized_add that is causing an issue internally.
In general, make_optimized should only be called with valid new_args; new_args can become None
when an element already exists, and we should break out of the loop in that case.
Note that I also kept the optimized summation only when both lhs and rhs lengths are <= 2.
This is OK because the optimization is based on the inductive property of adding one symbol at a time;
the [2] + [2] case here serves as the base case (I feel we could also remove it).
Keeping it for all sizes would be correct, but I am not sure it would be as efficient (we would do N log(N) insertions),
and there is no current justification for that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150955
Approved by: https://github.com/Mingming-Ding, https://github.com/atalman, https://github.com/bobrenjc93
This PR was inspired by internal models that were cache-missing due to PGO. At a high level, the problem looks as follows:
Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic
Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....
Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.
We solve this problem by hashing the source names, which gives a somewhat stable assignment, and we use linear probing to prevent catastrophic symbol collisions.
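A self-contained sketch of the "hash the source name, linear-probe on collision" idea; the symbol-space size and naming scheme are illustrative, not the real ShapeEnv implementation:
```python
from zlib import crc32

SYMBOL_SPACE = 1 << 16
_assigned: dict[int, str] = {}   # symbol index -> source name

def assign_symbol(source_name: str) -> str:
    idx = crc32(source_name.encode()) % SYMBOL_SPACE
    while idx in _assigned and _assigned[idx] != source_name:
        idx = (idx + 1) % SYMBOL_SPACE          # linear probing avoids collisions
    _assigned[idx] = source_name
    return f"s{idx}"                             # stable across runs for a given source

print(assign_symbol("L['x'].size()[0]"))
print(assign_symbol("L['y'].size()[0]"))
```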
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
Adds the option `torch.fx.experimental._config.backed_size_oblivious = True` to allocate `[0, inf]` instead of `[2, inf]` ranges for backed size symbols, opting into size-oblivious semantics for them.
This helps in a number of cases:
- Keeps `[0, inf]` bounds for unbacked symbols when we make an unbacked -> backed replacement
- More sound handling of 0/1 inputs at runtime when we lower from export
- Avoids end-of-bounds / sys.maxsize constraint violations when exporting with named Dims (https://github.com/pytorch/pytorch/issues/146315, https://github.com/pytorch/pytorch/issues/146046)
We may look toward turning this on globally for export.
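A minimal usage sketch of the flag with export and a named Dim (the module and shapes are illustrative, not from this PR's tests):
```python
import torch
from torch.fx.experimental import _config as fx_experimental_config

fx_experimental_config.backed_size_oblivious = True   # opt into [0, inf] ranges

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

ep = torch.export.export(
    M(), (torch.randn(3),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
)
print(ep)
```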
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148696
Approved by: https://github.com/bobrenjc93
Fixes https://github.com/pytorch/pytorch/issues/144095
Open to suggestions: the `hint_int(..., fallback=...)` API feels like a bit of a footgun, because:
(1) we use the same guess for every unbacked symint (both lone symbols and compound expressions), and
(2) the user may have established relationships between unbacked symints that we are not taking into account.
I'm not sure how real of an issue (2) is: is it common to, e.g., generate two unbacked symints and then add a runtime assert that they are unequal?
Instead, I did something simpler that is just enough to fix the linked issue: if we have a sympy expression containing an unbacked symbol (e.g. `u0 + 1`), the partitioner now fills in the symbol with our guess rather than replacing the whole expression (plugging in `u0 = 4096` gets us 4097). This was important for an internal custom op that had logic like this:
```
def custom_op(x: [u0], y: [u0 + 1]):
    assert x.shape[0] == y.shape[0] - 1
    ...
```
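A small sketch of the substitution behavior: the fallback hint is plugged into the unbacked symbol and the expression is evaluated, rather than replacing the whole expression with the hint (the 4096 guess matches the example above):
```python
import sympy

u0 = sympy.Symbol("u0", integer=True, positive=True)
size_expr = u0 + 1

UNBACKED_HINT = 4096
# x gets hint 4096 and y gets hint 4097, keeping the custom op's invariant.
print(size_expr.subs({u0: UNBACKED_HINT}))   # 4097
```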
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144097
Approved by: https://github.com/laithsakka
Summary:
This is very much the same solution proposed by bobrenjc93, except that it restricts it to expressions and axioms that contain FloorDiv, since those are the only ones that could have become CleanDiv, and the only ones that can change as the shape env changes.
This also does not break the torchrec benchmarks. It might be worth understanding why the generalized version does break them, but we could just be hitting another bug or NYI situation.
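A small sketch of the restriction, using the real `FloorDiv` sympy function from `torch.utils._sympy.functions` (the filter function itself is illustrative):
```python
import sympy
from torch.utils._sympy.functions import FloorDiv

def mentions_floordiv(expr: sympy.Expr) -> bool:
    # Only expressions/axioms containing FloorDiv are re-examined, since only
    # those can turn into CleanDiv as the shape env learns more facts.
    return expr.has(FloorDiv)

s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
print(mentions_floordiv(FloorDiv(s0, s1)))   # True  -> candidate for re-check
print(mentions_floordiv(s0 + s1))            # False -> skip
```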
Overhead? None on:
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=1000
```
Differential Revision: D66307433
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141267
Approved by: https://github.com/ezyang
Summary:
**Wins**
On the torchrec benchmark, for 2K nodes this saves 40 seconds;
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 seconds (with the max optimizations on).
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes the construction of expressions of the form
a + b + c + ... (all unique symbols),
which are very common in torchrec models.
**How**
Expressions of the form a + b + c are not optimized by Add; the only needed optimization is sorting the terms.
If we have a + b + c and we are adding d to it, we can do a binary search to find
the position of d and avoid re-optimizing the new expression by passing in the new order (see the sketch below).
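A self-contained sketch of the idea (the sort key and `evaluate=False` usage are illustrative, not the actual implementation):
```python
import bisect
import sympy

def add_sorted_symbol(sorted_terms, new_sym):
    # Find the slot for the new unique symbol with a binary search and build
    # the Add without re-running sympy's flatten/sort pass.
    keys = [sympy.default_sort_key(t) for t in sorted_terms]
    pos = bisect.bisect_left(keys, sympy.default_sort_key(new_sym))
    new_terms = sorted_terms[:pos] + [new_sym] + sorted_terms[pos:]
    return sympy.Add(*new_terms, evaluate=False), new_terms

a, b, c, d = sympy.symbols("a b c d")
expr, terms = add_sorted_symbol([a, b, c], d)
print(expr)   # a + b + c + d, already in canonical order
```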
**Extensions**:
1. Support constant terms.
2. Support 10a + 10b + ... (this will give even more wins; support will be extended in a second PR).
Differential Revision: D66008482
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140822
Approved by: https://github.com/ezyang