The strategy is taken from voz's #89392, but my implementation
is a bit different.
If a fake tensor is provided, we use its FakeTensorMode
(and more importantly, its ShapeEnv--this is what is tested
in the new unit test). Only one tensor needs to be fake;
if nothing is fake we just make a fresh mode as before.
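A minimal sketch of the selection logic described above, assuming the `torch._subclasses` FakeTensor/FakeTensorMode classes; the helper name is illustrative, not the PR's actual code:
```
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def _pick_fake_mode(args):
    # Reuse the caller's mode (and, importantly, its ShapeEnv) if any input is fake.
    fake_modes = {a.fake_mode for a in args if isinstance(a, FakeTensor)}
    if len(fake_modes) > 1:
        raise RuntimeError("fake tensor inputs must share a single FakeTensorMode")
    # Otherwise fall back to a fresh mode, as before.
    return fake_modes.pop() if fake_modes else FakeTensorMode()
```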
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89670
Approved by: https://github.com/voznesenskym
This is extracted from voz's #89392
Previously, the implementation did some half-assed caching where it
returned a callable that, when invoked for the first time, actually
performed the compilation. Delaying the compilation like this
seems totally unnecessary. To make matters worse, it has a cost
(we have to check whether we hit the cache) and is unsound (because the
compiled function may not be valid for other arguments).
So instead, we ask user to provide arguments, and compile everything
immediately.
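For contrast, a hypothetical before/after sketch; `backend_compile` and the surrounding structure are illustrative, not the PR's code:
```
# Before: return a callable that compiles lazily on first call; every call pays a
# cache check, and the cached result may not be valid for different arguments.
def compile_lazily(fn):
    cache = {}
    def wrapper(*args):
        if "compiled" not in cache:
            cache["compiled"] = backend_compile(fn, args)
        return cache["compiled"](*args)
    return wrapper

# After: the caller supplies example arguments and compilation happens immediately.
def compile_eagerly(fn, example_args):
    return backend_compile(fn, example_args)
```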
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89669
Approved by: https://github.com/voznesenskym, https://github.com/Chillee
There was a lot of strangeness in how AOTAutograd backends were previously defined. This refactor replaces the strangeness with something simple and straightforward. The improvements:
- There is no longer a footgun aot_autograd "backend" which doesn't actually work. No more mistyping `torch._dynamo.optimize("aot_autograd")` when you meant "aot_eager".
- Deleted aot_print because it's annoying and there are no uses of it anyway.
- Instead of having BOTH the backend Subgraph and AotAutogradStrategy, there is now only an aot_autograd function which takes the kwargs to configure AOTAutograd, and then gives you a compiler function that does AOTAutograd given those kwargs (sketched below). Easy.
- The primary downside is that we are now eagerly populating all of the kwargs, which can get us into import cycle shenanigans. Some cycles I resolved directly (e.g., we no longer manually disable the forward function before passing it to aot_autograd; aot_autograd does it for us), but for getting inductor decompositions I had to make it take a lambda so I could lazily populate the decomps later.
New code is 130 lines shorter!
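A hedged sketch of the shape described above; `aot_module_simplified` is the AOTAutograd entry point such a backend would go through, but the exact wiring here is illustrative rather than the PR's code:
```
from functorch.compile import aot_module_simplified

def aot_autograd(**kwargs):
    # kwargs are the AOTAutograd knobs (fw_compiler, bw_compiler, decompositions, ...).
    def compiler_fn(gm, example_inputs):
        return aot_module_simplified(gm, example_inputs, **kwargs)
    return compiler_fn

# e.g., an aot_eager-style backend might then fall out as roughly:
# aot_eager = aot_autograd(fw_compiler=lambda gm, _: gm)
```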
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89736
Approved by: https://github.com/anjali411, https://github.com/albanD
A few things in this PR, that I found useful while debugging some
recent issues:
- We now allocate an aot_id to each aot_function/aot_module invocation,
  and print it in error messages and graph output logging. Check the
  comment for why this sort of thing is useful, and also why it's
  different from nth_graph. This number is now incorporated into
  aot_graph_name.
- I noticed that nth_graph only gets incremented when backwards is
  compiled. Because backwards is compiled lazily, this means that
  multiple forward graphs would have gotten the same ID! I changed
  nth_graph to always increment to avoid confusion here.
- I added a simple describe_input function, which makes use of
  num_params_buffers to tell the user whether the input index they're
  looking at is a param/buffer or an input. With the help of
  https://github.com/pytorch/pytorch/pull/89709 we could give
  even more detailed information about inputs. (We could also
  easily give detailed information about parameters if we stored
  a mapping of index to parameter name, but I didn't need this
  when debugging, so I'll let someone else add it if they need
  it.)
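A minimal sketch (assumed, not the PR's exact code) of what describe_input does: with aot_module, the first num_params_buffers flattened inputs are parameters/buffers and the rest are user inputs.
```
def describe_input(i: int, num_params_buffers: int) -> str:
    # Indices below num_params_buffers are the module's params/buffers;
    # everything after that is a user-supplied input.
    if i < num_params_buffers:
        return f"parameter/buffer {i}"
    return f"input {i - num_params_buffers}"
```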
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89710
Approved by: https://github.com/bdhirsh
It turns out that instead of having a giant blobby aot_dispatch_autograd
function, we can factor it into a series of wrapper functions, each
of which successively guarantees more invariants on the inner
compilation function until the final inner function is quite trivial.
How exactly you have to wrap the input user functions and the output
compiled functions can be expressed concisely in Haskell, so I've
included the Haskell formulation in code comments.
This PR shows how to do this for input deduplication. Dealing with the
rest of the view handling is left to future work.
This PR should also be a slight performance improvement as deduplicating
is skipped entirely when there are no duplicate inputs.
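A hedged Python sketch of the input-deduplication wrapper pattern (the real code is more involved and carries the Haskell formulation in comments); `dedup_inputs` and `compiler` are illustrative names, with `compiler` standing for any (fn, example_args) -> compiled_fn callable:
```
def dedup_inputs(flat_fn, flat_args, compiler):
    # Map each (possibly duplicated) argument position to a position in the
    # deduplicated argument list.
    seen, keep_idx, remap = {}, [], []
    for i, a in enumerate(flat_args):
        if id(a) in seen:
            remap.append(seen[id(a)])
        else:
            seen[id(a)] = len(keep_idx)
            remap.append(len(keep_idx))
            keep_idx.append(i)

    if len(keep_idx) == len(flat_args):
        # No duplicates: skip the wrapping entirely (the fast path noted above).
        return compiler(flat_fn, flat_args)

    # Inner function: guaranteed to see only unique arguments.
    def inner_fn(*deduped_args):
        return flat_fn(*(deduped_args[j] for j in remap))

    compiled = compiler(inner_fn, [flat_args[i] for i in keep_idx])

    # Outer wrapper: callers keep passing the original, possibly duplicated, arguments.
    def wrapped(*args):
        return compiled(*(args[i] for i in keep_idx))
    return wrapped
```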
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89701
Approved by: https://github.com/bdhirsh
The preexisting logic here added in
https://github.com/pytorch/functorch/pull/970 was very peculiar: if top_kwargs
was non-empty, then the inner compiled function supports kwargs. Naively, this
would lead you to expect that there is some sort of correlation between
top_kwargs and kwargs. But in fact, they're completely unrelated! top_kwargs
is the AOTAutograd configuration knobs (e.g., fw_compiler/bw_compiler), but
kwargs is the RUNTIME kwargs that are to be passed to the compiled function.
But (1) we don't support this (the function to be compiled only takes a list
of tensors) and (2) even if we did support it, conditioning on whether or not
you had passed AOTAutograd configuration kwargs to support kwargs at runtime
is bonkers.
So delete it.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89664
Approved by: https://github.com/voznesenskym
Using the same repro from the issue (but with BatchNorm2D)
Rectifies native_batch_norm schema by splitting the schema into 2:
1. one will have NON-optional alias-able running_mean and running_var inputs
2. the other will just not have those parameters at all (no_stats variation)
**Calling for name suggestions!**
## test plan
I've added tests in test_functionalization.py as well as an entry in common_method_invocations.py for `native_batch_norm_legit`
CI should pass.
## next steps
For BC/FC reasons, we reroute native_batch_norm to call our new schemas ONLY through the Python dispatcher, but in two weeks or so we should make `native_batch_norm_legit` the official batch_norm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88697
Approved by: https://github.com/albanD
Summary:
We want to introduce an experimental control flow op, map(), so that we can export some models as FX graphs correctly.
Some clarification on the basic requirements we have in mind:
1. This op can nest cond() and other control flow primitives internally.
2. We don't necessarily need loop carried dependencies for the models we've seen.
3. This map() op can handle dynamically shaped tensors as input and return dynamically shaped outputs based on input shapes.
4. We should be able to pass through additional arguments to the loop body as extra arguments.
In this diff we introduce a new control flow op `map()` which has the following semantics:
```
def map(f: Callable, xs: Tensor, *args):
    # one possible implementation:
    return torch.stack([f(x, *args) for x in xs])
```
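A hedged usage sketch of those semantics; the import path is an assumption, and the body function and argument names are illustrative:
```
import torch
from functorch.experimental.control_flow import map as control_flow_map

xs = torch.randn(4, 3)      # mapped over the leading dimension
bias = torch.tensor(1.0)    # passed through unchanged to every iteration

def body(x, b):
    return x.sin() + b

out = control_flow_map(body, xs, bias)          # shape: (4, 3)
ref = torch.stack([body(x, bias) for x in xs])  # matches the reference implementation
```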
Test Plan:
pytest functorch/test_control_flow.py
CI
Differential Revision: D41165796
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88767
Approved by: https://github.com/zou3519
This also comes with some bug fixes that were uncovered from doing
this:
- Forward device calls to inner tensor in FunctionalTensorWrapper
- Make legacyExtractDispatchKey exclude Functionalize, so that
it can get at the real device type key. This is noncontroversial.
- Stop stripping dense from the key set. The reason for this is that
  FunctionalTensorWrapper may be used in contexts where people
  query whether it is dense or not. If it doesn't report this correctly
  (from the dispatch key), it will cause errors. This caused some
  torchbench models to fail when I did one-pass tracing.
- Save and restore reapply views TLS correctly
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88063
Approved by: https://github.com/bdhirsh
Previously the permute function was extended to behave like the `order`
function for first-class dimensions. However, unlike `permute`,
`order` doesn't have a keyword argument `dims`, and there is no way to add
it in a way that lets both `permute` and `order` continue to have the same
behavior. So this change just removes the extra functionality of permute,
which wasn't documented anyway. Fixes #88187
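For illustration, a hedged sketch of the intended split, using functorch's first-class dims; the exact example is illustrative:
```
import torch
from functorch.dim import dims

i, j = dims(2)
x = torch.randn(3, 4)[i, j]   # bind first-class dims i (size 3) and j (size 4)

y = x.order(j, i)                    # `order` picks a positional ordering of first-class dims
z = torch.randn(3, 4).permute(1, 0)  # plain `permute` keeps its usual Tensor behavior
```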
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88226
Approved by: https://github.com/zou3519
This refactor was prompted by challenges handling mixed int/float
operations in C++. A previous version of this patch
(https://github.com/pytorch/pytorch/pull/87722/) added overloads for each
permutation of int/float and was unwieldy. This PR takes a different
approach.
The general outline of the patch is to combine the C++ types SymIntNode
and SymFloatNode into a single type, SymNode. This type is erased; we
no longer know statically in C++ whether we have an int or a float and have to test
it with the is_int()/is_float() virtual methods. This has a number of
knock-on effects.
- We no longer have C++ classes to bind to Python. Instead, we take an
  entirely new approach to our Python API, where we have SymInt/SymFloat
  classes defined entirely in Python, which hold a SymNode (corresponding
  to the C++ SymNode). However, SymNode is not pybind11-bound; instead,
  it lives as-is in Python and is wrapped into the C++ SymNode using PythonSymNode
  when it crosses into C++. This implies a userland rename.
In principle, it is also possible for the canonical implementation of SymNode
to be written in C++, and then bound to Python with pybind11 (we have
this code, although it is commented out.) However, I did not implement
this as we currently have no C++ implementations of SymNode.
Because we do return SymInt/SymFloat from C++ bindings, the C++ binding
code needs to know how to find these classes. Currently, this is done
just by manually importing torch and getting the attributes.
- Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now
takes SymInt/SymFloat, rather than SymNode, bringing it in line with how
__torch_dispatch__ works.
Some miscellaneous improvements:
- SymInt now has a constructor that takes SymNode. Note that this
constructor is ambiguous if you pass in a subclass of SymNode,
so an explicit downcast is necessary. This means toSymFloat/toSymInt
are no more. This is a mild optimization as it means rvalue reference
works automatically.
- We uniformly use the caster for c10::SymInt/SymFloat, rather than
going the long way via the SymIntNode/SymFloatNode.
- Removed some unnecessary toSymInt/toSymFloat calls in normalize_*
functions, pretty sure this doesn't do anything.
- guard_int is now a free function, since to guard on an int you cannot
assume the method exists. A function can handle both int and SymInt
inputs.
- We clean up the magic method definition code for SymInt/SymFloat/SymNode.
ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets
plain methods; this is to help avoid confusion between the two types.
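A minimal sketch (not the real implementation) of the Python-side layering described above: the user-facing wrappers hold a type-erased SymNode, and only the wrappers expose magic methods. The class names come from the description; the bodies here are assumptions for illustration.
```
class SymNode:  # type-erased; exposes plain methods only
    def __init__(self, expr, is_int):
        self.expr = expr
        self._is_int = is_int

    def is_int(self):
        return self._is_int

    def is_float(self):
        return not self._is_int

    def add(self, other):
        # int + int stays int; anything involving a float becomes float
        return SymNode(f"({self.expr} + {other.expr})", self._is_int and other._is_int)


class SymInt:  # user class; magic methods delegate to the underlying node
    def __init__(self, node):
        self.node = node

    def __add__(self, other):
        return SymInt(self.node.add(other.node))
```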
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
Previously we claimed that "forward-mode AD coverage is not that good".
We've since improved it so I clarified the statement in our docs and
downgraded the warning to a note.
Test Plan:
- view docs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87383
Approved by: https://github.com/samdow
Fixes https://github.com/pytorch/functorch/issues/1052
I got here after some discussion with Alban. Today, if you aot_function() trace a program where some of its inputs have `requires_grad=True`, but some outputs are expected to have `requires_grad=False`, we will incorrectly set all outputs to have `requires_grad=True`.
A simple solution is to use autograd.function's API for marking outputs as non-differentiable, based on what we witnessed when we traced the forward.
This will make the `autograd.Function` that we return **wrong** if you created it using inputs that required grad and try to re-use it with inputs that have a different `requires_grad` field. But as long as we're hiding behind dynamo, which should guard on requires_grad, we'll re-run `aot_function()` and get out a new compiled function that does the right thing.
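A hedged sketch of that approach, using autograd.Function's standard mark_non_differentiable API; `CompiledFunction`, `compiled_forward`, and `compiled_backward` are hypothetical names, not the PR's code:
```
import torch

class CompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *inputs):
        out0, out1 = compiled_forward(*inputs)  # hypothetical compiled forward
        # At trace time we witnessed that the second output did not require grad,
        # so mark it non-differentiable instead of letting it inherit requires_grad=True.
        ctx.mark_non_differentiable(out1)
        return out0, out1

    @staticmethod
    def backward(ctx, grad_out0, _grad_out1):
        return compiled_backward(grad_out0)     # hypothetical compiled backward
```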
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86838
Approved by: https://github.com/ezyang
We initially left it there for BC concerns.
- It has been more than a month since then.
- I have migrated folks who used the previous install command (pip
install ...pytorch.git@subdir=functorch) off of it.
So it's time to get rid of it.
Test Plan:
- code reading
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87235
Approved by: https://github.com/Chillee
`min_cut_rematerialization_partition` has a default set of hard-coded operations that are allowed to be recomputed in the backward pass.
This PR adds customization ability to this function allowing users to control the behavior by passing `recomputable_ops` instead of relying on the default setting.
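A hedged usage sketch; the op allow-list here is just an example, not a recommendation:
```
import torch
from functools import partial
from functorch.compile import min_cut_rematerialization_partition

# Allow only pointwise add/mul to be recomputed in backward, instead of the default set.
my_partition_fn = partial(
    min_cut_rematerialization_partition,
    recomputable_ops=[torch.ops.aten.add, torch.ops.aten.mul],
)
# my_partition_fn can then be passed as the partition_fn to aot_function/aot_module.
```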
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86686
Approved by: https://github.com/Chillee
Added some details about:
- `pip uninstall functorch` being helpful if there are problems
- `pip install functorch` still working for BC reasons.
Test Plan:
- wait for docs preview
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86823
Approved by: https://github.com/samdow
Fix for https://github.com/pytorch/torchdynamo/issues/1368
From comment:
> When we invoke a Composite Implicit autograd operator that has an autocast rule, such as Einsum,
autocast is disabled during its invocation. When we trace out the operators in an implicit op,
re-applying autocast rules on those operators might yield divergence from what was executed at runtime.
This pass checks for divergence. If divergence is found, we will disable autocast.
We would like to avoid disabling autocast if possible because accessing TLS is slow.
Concretely, the problem was found when `sum` was invoked inside `einsum`,
as seen by the following divergence:
```
>>> with torch.cuda.amp.autocast(enabled=True):
... print(torch.ops.aten.sum.dim_IntList(torch.rand([2, 2, 2], device="cuda", dtype=torch.half), [1, 2]).dtype)
...
torch.float32
>>> print(torch.ops.aten.sum.dim_IntList(torch.rand([2, 2, 2], device="cuda", dtype=torch.half), [1, 2]).dtype)
torch.float16
```
Edit: we've decided to accept the overhead of universally disabling autocast instead
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86515
Approved by: https://github.com/bdhirsh, https://github.com/Chillee
- supports saving symint (and symfloat..) values between fw/bwd, using sketchy logic that probably needs to be improved but seems to work so far
- sets a correct weight=1 for sym nodes for cost purposes
- lets user functions return symints/floats (but if the same symfloat is saved for backward, that gets duplicated annoyingly)
- makes partitioning decisions based on observed trace-time sizes without guarding! (this is sketchy, but it isn't clear that it will lead to bad partitioning choices either)
- improves infra for tracking symint-family of types: is_sym_node() and _py_sym_types
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86425
Approved by: https://github.com/ezyang
- Symintify split_with_sizes, dropout, fused_fake_obs_quant; meta for padding_2d ops
- Add meta_bernoulli_
- Meta kernel for at::gather
- Get pytorch_struct to pass: meta for scatter_add, fix backward
- Symintify split ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86488
Approved by: https://github.com/ezyang