I regretted the decision in
https://github.com/pytorch/pytorch/pull/130606. Most user
torch_dispatchs don't have enough to actually handle the HOP correctly,
so for now I'd prefer that users explicitly define the interaction
between the HOP and their torch_dispatch class.
An example is FlopCounterMode: if we allow HOPs to get passed to it, it
will ignore auto_functionalized(mm) by default but it will record flops
for mm, which is weird.
Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131370
Approved by: https://github.com/ydwu4
This involved beefing up the Python dispatcher to handle torch_dispatch.
Given a HOP and a torch_dispatch Tensor subclass:
- the HOP will show up in the subclass's `__torch_dispatch__`
- you can also use HOP.py_impl to register a rule for the HOP x
subclass interaction
- (coming soon) we'll offer a way to open register HOP x subclass
interaction without needing to touch the subclass's
`__torch_dispatch__` or the HOP's .py_impl.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130606
Approved by: https://github.com/ydwu4
If a user accesses an OpOverloadPacket, then creates a new OpOverload,
then uses the OpOverloadPacket, the new OpOverload never gets hit. This
is because OpOverloadPacket caches OpOverloads when it is constructed.
This PR fixes the problem by "refreshing" the OpOverloadPacket if a new
OpOverload gets constructed and the OpOverloadPacket exists.
Test Plan:
- new tests
This is the third land attempt. The first one was reverted for breaking
internal tests, the second was reverted for being erroneously suspected
of causing a perf regression.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128000
Approved by: https://github.com/albanD
If a user accesses an OpOverloadPacket, then creates a new OpOverload,
then uses the OpOverloadPacket, the new OpOverload never gets hit. This
is because OpOverloadPacket caches OpOverloads when it is constructed.
This PR fixes the problem by "refreshing" the OpOverloadPacket if a new
OpOverload gets constructed and the OpOverloadPacket exists.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126863
Approved by: https://github.com/albanD
This adds a new dispatch mode, PreDispatchSchemaCheckMode, built on top of SchemaCheckMode, used for verifying op schemas for functionalization for PreDispatch IR. More specifically, the mode runs in eager mode on concrete inputs, checking if op schemas incorrectly claim to be functional, but are aliasing or mutating. This mode is pushed to the pre-dispatch mode stack, and run before decompositions.
Current testing is hooked up to OpInfo, containing 1103 tests on 600 unique ops. Below is a list of ops that fail testing. One caveat is we only raise errors on ops that claim to be functional - if an op schema admits aliasing or mutating but fails testing for the other, it still may decompose further and become functional.
List of failed ops:
```
aten.atleast_1d.default
aten.atleast_2d.default
aten.atleast_3d.default
aten.cartesian_prod.default
aten.conj_physical.default
aten.alpha_dropout.default
aten.feature_dropout.default
aten.feature_alpha_dropout.default
aten.unsafe_chunk.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125481
Approved by: https://github.com/tugsbayasgalan
A re-land of #124239.
This PR fakify ScriptObject inputs and attributes in export non-strict mode by default.
The basic idea is to only fakify the script object during tracing (i.e. aot_export). After we get the traced graph module, eagerly executing, serializing, or running more passes will use the real script objects. This is essentially treating the script object as constant tensor.
Concretely, we
fakify all the script object inputs, and module attributes (gathered by constant_attrs).
patch the module's attributes with fakified script object
right after aot_export, remove the patching (to avoid changing the original module) then modify the exported graph module's attribute to real script object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125490
Approved by: https://github.com/angelayi
This PR fakify ScriptObject inputs and attributes in export non-strict mode by default.
The basic idea is to `only fakify the script object during tracing (i.e. aot_export)`. After we get the traced graph module, eagerly executing, serializing, or running more passes will use the real script objects. This is essentially treating the script object as constant tensor.
Concretely, we
1. fakify all the script object inputs, and module attributes (gathered by constant_attrs).
2. patch the module's attributes with fakified script object
3. right after aot_export, remove the patching (to avoid changing the original module) then modify the exported graph module's attribute to real script object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124239
Approved by: https://github.com/zou3519
fake_tensor.py had mypy error ignored. That seems less than desirable.
Also added SafePyObjectT<T> which is a tagged wrapper around a SafePyObject but provides static type checking (with no other guarantees).
Used `SafePyObjectT<TorchDispatchModeKey>` on some of the TorchDispatchModeTLS API to ensure that we don't accidentally inject a different type than expected into the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428
Approved by: https://github.com/malfet
fake_tensor.py had mypy error ignored. That seems less than desirable.
Also added SafePyObjectT<T> which is a tagged wrapper around a SafePyObject but provides static type checking (with no other guarantees).
Used `SafePyObjectT<TorchDispatchModeKey>` on some of the TorchDispatchModeTLS API to ensure that we don't accidentally inject a different type than expected into the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428
Approved by: https://github.com/malfet
If a user accesses an OpOverloadPacket, then creates a new OpOverload,
then uses the OpOverloadPacket, the new OpOverload never gets hit. This
is because OpOverloadPacket caches OpOverloads when it is constructed.
This PR fixes the problem by "refreshing" the OpOverloadPacket if a new
OpOverload gets constructed and the OpOverloadPacket exists.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124654
Approved by: https://github.com/albanD
We override the `__call__` method and register fake, functional, proxy default dispatch mode implementation in its python_key_mode_table.
The idea is:
1. when inputs contains FakeScriptObject, we dispatch it through _get_dispatch mechanism. We implement dispatch mode keys automatically in the operator's constructor.
2. when inputs are not fakified, we dispatch through the original c++ dispatcher.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123367
Approved by: https://github.com/zou3519
A kernel has "dispatcher convention" if there is an additional keyset
arg at the beginning of the argument list. This PR:
- adds a way to register kernels with dispatcher_convention using
Library.impl (pass dispatcher_convention = True)
- adds OpOverload.redispatch
We use both of the above in the new custom ops API: we register the
autograd kernel in dispatcher convention so that we can actually call
redispatch like how pytorch built-in ops do it.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124089
Approved by: https://github.com/albanD
ghstack dependencies: #123937, #124064, #124065, #124066, #124071
If a user accesses an OpOverloadPacket, then creates a new OpOverload,
then uses the OpOverloadPacket, the new OpOverload never gets hit. This
is because OpOverloadPacket caches OpOverloads when it is constructed.
This PR fixes the problem by "refreshing" the OpOverloadPacket if a new
OpOverload gets constructed and the OpOverloadPacket exists.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123578
Approved by: https://github.com/albanD
ghstack dependencies: #123453
When we enter map_autograd, we try to trace through fwd/bwd of a map operator that is wrapped in ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the python part). As a result, it revealed our previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off PreDispatch key). This PR fixes that. Also I shuffled some APIs around so that there is less code duplication as the setting/unsetting logic is quite hard to get it right.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444
Approved by: https://github.com/bdhirsh
When we enter map_autograd, we try to trace through fwd/bwd of a map operator that is wrapped in ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the python part). As a result, it revealed our previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off PreDispatch key). This PR fixes that. Also I shuffled some APIs around so that there is less code duplication as the setting/unsetting logic is quite hard to get it right.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444
Approved by: https://github.com/bdhirsh
When we enter map_autograd, we try to trace through fwd/bwd of a map operator that is wrapped in ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the python part). As a result, it revealed our previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off PreDispatch key). This PR fixes that. Also I shuffled some APIs around so that there is less code duplication as the setting/unsetting logic is quite hard to get it right.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444
Approved by: https://github.com/bdhirsh
In case no keyword arguments are passed, `**kwargs` would expand just fine without the need for extra overhead of `or {}`. In addition to reducing boilerplate, this also comes with a small perf improvement:
```
In [1]: def null(*args, **kwargs):
...: pass
...:
In [2]: def call1(*args, **kwargs):
...: return null(*args, **(kwargs or {}))
...:
In [3]: def call2(*args, **kwargs):
...: return null(*args, **kwargs)
...:
In [4]: %timeit call1()
145 ns ± 2.07 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [5]: %timeit call2()
118 ns ± 2.14 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [6]: %timeit call1()
147 ns ± 6.19 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [7]: %timeit call2()
117 ns ± 0.846 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117880
Approved by: https://github.com/Skylion007
In this PR, we are implementing Functionalization on pre-dispatch graph. Today, every dispatch key except for Dispatchkey.Python has a dedicated mode stack in python. PreDispatch tracing relies on this behaviour by pushing ProxyTorchDispatchMode to Dispatchkey.PreDispatch mode stack and handle the dispatching logic in python. To make pre-dispatch functionalization work, we now need to push FunctionalTensorMode on DispatchKey.PreDispatch mode stack and make sure it runs before ProxyTorchDispatchMode. (this is very similar to how post-dispatch tracing work). Here are some design decisions we made for this flow to work:
1. FunctionalTensorMode internally calls C++ functionalize key. Since C++ functionalization goes after PreDispatch, if we are not careful, we will keep re-entering into PreDispatch key. We solve this by directly dispatching to C++ Functionalize key.
2. We delete mode_stack_per_key logic because the only realistic time it is exercised is for PreDispatch and it is in general not safe to have a plain list because FunctionalTensorMode and ProxyTorchDispatchMode ordering matter and it is hard to enforce it on plain list. Instead, now we have a private class that tracks PreDispatch mode stack.
3. We will still run CompositeImplicitAutograd decomps in this PR, and disable this logic later as a followup.
Some missing bits after this PR:
1. Preserving autograd ops in a functional form. Right now they still show up in the graph but in a "non-functional" way.
2. Turn off CompositeImplicitAutograd decomps
3. Functionalizing HOO
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113728
Approved by: https://github.com/bdhirsh
Summary:
This is important for writing aten IR based graph transformation.
```
In [4]: [x.name for x in torch.ops.aten.reshape.default._schema.arguments]
Out[4]: ['self', 'shape']
In [8]: torch.ops.aten.reshape.default(torch.rand(1,2), shape=[2])
Out[8]: tensor([0.7584, 0.4834])
# === CANNOT CALL `self` BY KWARGS ===
In [7]: torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[7], line 1
----> 1 torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])
TypeError: OpOverload.__call__() got multiple values for argument 'self'
```
# Where's the problem?
1. the aten ops first arg is usually named `self` (aten/src/ATen/native/native_functions.yaml)
2. Unfortunately, in `torch._ops.{OpOverload, OpOverloadPacket}.__call__()`, the first arg is (by python convention) named `self` too.
So when call `self` by kwargs, `OpOverloadPacket.__call__` received:
```
OpOverloadPacket.__call__(self, {"self": ...})
```
It is Python that does not allow some argument named "arg" to appear twice. and hence
> TypeError: OpOverload.__call__() got multiple values for argument 'self'
# How to fix?
**Note that**, in above, `self` is an instance of `OpOverloadPacket`, and the "self" kwarg is the input tensor to the aten op. To fix, we only need to differentiate the two `self`s.
In Python, first arg of a method does not need to be named `self`. So we change the `__call__` definition to:
```
def __call__(_self, ...):
```
Now the call becomes:
```
OpOverloadPacket.__call__(_self, {"self": ...})
```
where:
* `_self` is the instance to the `OpOverloadPacket`
* `"self"` is the input tensor to the aten op.
Test Plan:
```
In [4]: [x.name for x in torch.ops.aten.reshape.default._schema.arguments]
Out[4]: ['self', 'shape']
In [3]: torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])
Out[3]: tensor([0.5127, 0.3051])
```
Differential Revision: D51731996
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114920
Approved by: https://github.com/houseroad
Users may wish to torch.compile custom ops that mutate their inputs
and return nothing (this is a common class of operators).
torch.compile will automatically support this op without anyone needing
to provide a functionalization kernel for it. Here's how.
Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
op. First, when FakeTensor sees this op, it can just return None.
This is the case because custom ops are not allowed to mutate input
metadata, so the FakeTensor rule for one that returns nothing is trivial.
Next, when Python FunctionalTensor sees the op, it will functionalize
it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
HOP and replacing the mutated inputs with the outputs of this HOP.
This HOP effectively runs the functional version of the op when
called: it clones inputs that will be mutated, runs the op, and
then returns Tensors with the new values.
In the future we can teach Inductor how to do re-inplacing when it sees
this HOP (like how triton kernels do it) but this isn't urgent (and is
more of a performance problem).
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114955
Approved by: https://github.com/bdhirsh