Commit Graph

206 Commits

e1c2bdac2f [easy] fix f-string messages in torch/_ops.py (#132531)
I encountered these when making this change:

```
diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py
index 3a2e07fa147..a4d003399e7 100644
--- a/test/functorch/test_ac.py
+++ b/test/functorch/test_ac.py
@@ -259,15 +259,8 @@ class MemoryBudgetTest(TestCase):

         expected = call()
         for budget in range(0, 11):
-            memory_budget = budget / 10
-            torch._dynamo.reset()
-            with config.patch(activation_memory_budget=memory_budget):
-                if memory_budget is not None:
-                    f_compile = torch.compile(
-                        call, backend="aot_eager_decomp_partition"
-                    )
-
-                self.assertEqual(expected, f_compile())
+            get_mem_and_flops(call, memory_budget=budget / 10)
+

     def test_prioritize_cheaper_matmul(self):
         def f(xs, ws):
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132531
Approved by: https://github.com/Skylion007
2024-08-05 18:58:33 +00:00
5dac4d2c78 Revert "[easy] fix f-string messages in torch/_ops.py (#132531)"
This reverts commit 908d2a153b14cbb7a39c1f4ef9a77534cf2c71bf.

Reverted https://github.com/pytorch/pytorch/pull/132531 on behalf of https://github.com/davidberard98 due to still breaks tests ([comment](https://github.com/pytorch/pytorch/pull/132531#issuecomment-2267584289))
2024-08-04 15:41:56 +00:00
908d2a153b [easy] fix f-string messages in torch/_ops.py (#132531)
I encountered these when making this change:

```
diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py
index 3a2e07fa147..a4d003399e7 100644
--- a/test/functorch/test_ac.py
+++ b/test/functorch/test_ac.py
@@ -259,15 +259,8 @@ class MemoryBudgetTest(TestCase):

         expected = call()
         for budget in range(0, 11):
-            memory_budget = budget / 10
-            torch._dynamo.reset()
-            with config.patch(activation_memory_budget=memory_budget):
-                if memory_budget is not None:
-                    f_compile = torch.compile(
-                        call, backend="aot_eager_decomp_partition"
-                    )
-
-                self.assertEqual(expected, f_compile())
+            get_mem_and_flops(call, memory_budget=budget / 10)
+

     def test_prioritize_cheaper_matmul(self):
         def f(xs, ws):
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132531
Approved by: https://github.com/Skylion007
ghstack dependencies: #132356, #132466
2024-08-04 14:30:42 +00:00
21d02f8b4b Revert "[easy] fix f-string messages in torch/_ops.py (#132531)"
This reverts commit 25903f3932b3a24d4edf323484d2159f3ac92999.

Reverted https://github.com/pytorch/pytorch/pull/132531 on behalf of https://github.com/davidberard98 due to broke lint and tests due to conflict with 132377 ([comment](https://github.com/pytorch/pytorch/pull/132531#issuecomment-2266743391))
2024-08-03 14:49:07 +00:00
25903f3932 [easy] fix f-string messages in torch/_ops.py (#132531)
I encountered these when making this change:

```
diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py
index 3a2e07fa147..a4d003399e7 100644
--- a/test/functorch/test_ac.py
+++ b/test/functorch/test_ac.py
@@ -259,15 +259,8 @@ class MemoryBudgetTest(TestCase):

         expected = call()
         for budget in range(0, 11):
-            memory_budget = budget / 10
-            torch._dynamo.reset()
-            with config.patch(activation_memory_budget=memory_budget):
-                if memory_budget is not None:
-                    f_compile = torch.compile(
-                        call, backend="aot_eager_decomp_partition"
-                    )
-
-                self.assertEqual(expected, f_compile())
+            get_mem_and_flops(call, memory_budget=budget / 10)
+

     def test_prioritize_cheaper_matmul(self):
         def f(xs, ws):
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132531
Approved by: https://github.com/Skylion007
ghstack dependencies: #132356, #132466
2024-08-03 02:23:44 +00:00
ff4ca0d02a [Easy] Fix argument name collision in HigherOrderOperator dispatched functions (#132377)
Shares the same spirit as #129562

- #129562

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132377
Approved by: https://github.com/zou3519
2024-08-01 17:13:37 +00:00
5612408735 _get_operation_overload: dont raise exception when overload does not exist (#131554)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131554
Approved by: https://github.com/ezyang, https://github.com/zou3519
ghstack dependencies: #131403, #131482, #131665
2024-07-26 15:38:11 +00:00
4ac77fc6bd [HOP] Don't send HOPs to torch_dispatch (#131370)
I regretted the decision in
https://github.com/pytorch/pytorch/pull/130606. Most user
torch_dispatch implementations don't have enough information to actually
handle the HOP correctly, so for now I'd prefer that users explicitly define
the interaction between the HOP and their torch_dispatch class.

An example is FlopCounterMode: if we allow HOPs to get passed to it, it
will ignore auto_functionalized(mm) by default but it will record flops
for mm, which is weird.

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131370
Approved by: https://github.com/ydwu4
2024-07-23 13:41:08 +00:00
b29b23137c [Easy] Fix argument name collision in dispatched functions (#129562)
Use a positional-only argument to avoid a naming collision with aten op arguments that are named "self".

```python
In [1]: def foo(self, *args, **kwargs):
   ...:     print(self, args, kwargs)
   ...:

In [2]: def bar(self, /, *args, **kwargs):
   ...:     print(self, args, kwargs)
   ...:

In [3]: foo(1, 2, self=3)
TypeError: foo() got multiple values for argument 'self'

In [4]: bar(1, 2, self=3)
1
(2,)
{'self': 3}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129562
Approved by: https://github.com/zou3519, https://github.com/fegin
2024-07-17 14:39:56 +00:00
95046c86e3 [HOP] add HOP x torch_dispatch interaction (#130606)
This involved beefing up the Python dispatcher to handle torch_dispatch.
Given a HOP and a torch_dispatch Tensor subclass:
- the HOP will show up in the subclass's `__torch_dispatch__`
- you can also use HOP.py_impl to register a rule for the HOP x
  subclass interaction
- (coming soon) we'll offer a way to register the HOP x subclass
  interaction externally, without needing to touch the subclass's
  `__torch_dispatch__` or the HOP's .py_impl.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130606
Approved by: https://github.com/ydwu4
2024-07-12 21:51:36 +00:00
f093cd4086 Fix custom ops warning during export (#130623)
Fixes https://github.com/pytorch/pytorch/issues/130588

The problem was we were warning on all custom ops, not just ones marked
as CompositeImplicitAutograd. This PR changes the warning to just warn
on CompositeImplicitAutograd ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130623
Approved by: https://github.com/williamwen42
2024-07-12 21:34:29 +00:00
b4cc25f126 [custom_op]Fix self in mutation_args (#130179)
Fixes #124933

## Issue Summary
If users list `self` in mutates_args, an error occurs: `TypeError: AutoFunctionalized.__call__() got multiple values for argument 'self'`. For the following example, the schema for mutates_args is parsed as {"self": FakeTensor}.  6df963a2c8/torch/_higher_order_ops/auto_functionalize.py (L234)
In the above line, it is unwrapped as `self=FakeTensor`, which passes the wrong argument because `self` is the conventional name for the first parameter of a class's methods, such as https://github.com/pytorch/pytorch/compare/main...findhao/fix-self-custom-ops#diff-9453b6b52a54783beec3dd1c60248620f61c3a524d404a188af17bbdf6be3d9eR292 .
```python
import torch

@torch.library.custom_op("mylib::foo", mutates_args={"self"})
def foo(self: torch.Tensor) -> None:
    self.sin_()

x = torch.randn(3)

@torch.compile(backend="inductor", fullgraph=True)
def f(x):
    foo(x)

f(x)
```
## Fix
This PR changes all related default arguments named `self` to `self_`, following the existing approach in 6fc771d19b/torch/_ops.py (L667)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130179
Approved by: https://github.com/zou3519
2024-07-08 22:55:50 +00:00
e9c6e8369c Torchbind call method + effects support (#128397)
Adds effect token support to torchbind method calls by allowing `with_effects` to take in `torch.ops._higher_order_ops.call_torchbind` as an input.

Here is the print from `TORCH_LOGS="aot" python test/export/test_torchbind.py -k test_compile_obj_torchbind_op`:
```python
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2]", arg2_1):
    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1266 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
    cos: "f32[2]" = torch.ops.aten.cos.default(arg1_1)
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, cos);  arg0_1 = cos = None
    getitem: "f32[0]" = with_effects[0];  with_effects = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1267 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
    cos_1: "f32[2]" = torch.ops.aten.cos.default(arg1_1)
    add: "f32[2]" = torch.ops.aten.add.Tensor(cos_1, 1);  cos_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, add);  getitem = add = None
    getitem_2: "f32[0]" = with_effects_1[0];  with_effects_1 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1268 in f, code: torch.ops._TorchScriptTesting.queue_pop(tq)
    with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg2_1);  getitem_2 = None
    getitem_4: "f32[0]" = with_effects_2[0];  with_effects_2 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1269 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
    sin: "f32[2]" = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, sin);  getitem_4 = sin = None
    getitem_6: "f32[0]" = with_effects_3[0];  with_effects_3 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1270 in f, code: return tq.pop(), tq.pop() + tq.size(), tq
    with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop');  getitem_6 = None
    getitem_8: "f32[0]" = with_effects_4[0]
    getitem_9: "f32[2]" = with_effects_4[1];  with_effects_4 = None
    with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop');  getitem_8 = None
    getitem_10: "f32[0]" = with_effects_5[0]
    getitem_11: "f32[2]" = with_effects_5[1];  with_effects_5 = None
    with_effects_6 = torch._higher_order_ops.effects.with_effects(getitem_10, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'size');  getitem_10 = arg2_1 = None
    getitem_12: "f32[0]" = with_effects_6[0];  with_effects_6 = None
    add_1: "f32[2]" = torch.ops.aten.add.Tensor(getitem_11, 0);  getitem_11 = None
    return (getitem_12, getitem_9, add_1)
```

In order to support this, this PR makes the following changes:
* Adds `FakeScriptObject` to `CustomObjArgument`, which will be put on the `meta["val"]` of nodes representing torchbind objects.
* Adds pickle/deepcopy support to FunctionSchema.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128397
Approved by: https://github.com/ydwu4, https://github.com/zou3519
2024-06-14 21:28:17 +00:00
afe15d2d2f Flip default value for mypy disallow_untyped_defs [3/11] (#127840)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127840
Approved by: https://github.com/oulgen
2024-06-08 18:28:01 +00:00
6412c6060c [reland] Refresh OpOverloadPacket if a new OpOverload gets added (#128000)
If a user accesses an OpOverloadPacket, then creates a new OpOverload,
then uses the OpOverloadPacket, the new OpOverload never gets hit. This
is because OpOverloadPacket caches OpOverloads when it is constructed.

This PR fixes the problem by "refreshing" the OpOverloadPacket if a new
OpOverload gets constructed and the OpOverloadPacket exists.
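
To make the caching scenario concrete, here is a minimal sketch (the `mylib` namespace and op schemas below are made up for illustration):

```python
import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")

packet = torch.ops.mylib.foo       # OpOverloadPacket constructed; overloads cached here
print(packet.overloads())          # ['default']

# Register a new overload after the packet already exists...
lib.define("foo.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)")

# ...with the refresh in this PR, the packet picks it up instead of missing it forever.
print(packet.overloads())          # ['default', 'out']
```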

Test Plan:
- new tests

This is the third land attempt. The first one was reverted for breaking
internal tests, the second was reverted for being erroneously suspected
of causing a perf regression.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128000
Approved by: https://github.com/albanD
2024-06-05 17:57:09 +00:00
82a370ae3a Revert "Refresh OpOverloadPacket if a new OpOverload gets added (#126863)" (#127366)
This reverts commit ed734178abc99bc1d83ad2c61d3a1e4d4f5d20c8.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127366
Approved by: https://github.com/zou3519
2024-05-29 19:26:06 +00:00
ed734178ab Refresh OpOverloadPacket if a new OpOverload gets added (#126863)
If a user accesses an OpOverloadPacket, then creates a new OpOverload,
then uses the OpOverloadPacket, the new OpOverload never gets hit. This
is because OpOverloadPacket caches OpOverloads when it is constructed.

This PR fixes the problem by "refreshing" the OpOverloadPacket if a new
OpOverload gets constructed and the OpOverloadPacket exists.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126863
Approved by: https://github.com/albanD
2024-05-22 14:13:27 +00:00
2973c9bb88 [export] add SchemaCheckMode testing for pre-dispatch export, OpInfo (#125481)
This adds a new dispatch mode, PreDispatchSchemaCheckMode, built on top of SchemaCheckMode, used to verify op schemas for functionalization on pre-dispatch IR. More specifically, the mode runs in eager mode on concrete inputs, checking whether op schemas incorrectly claim to be functional while actually aliasing or mutating. This mode is pushed onto the pre-dispatch mode stack and runs before decompositions.

Current testing is hooked up to OpInfo, containing 1103 tests on 600 unique ops. Below is a list of ops that fail testing. One caveat is that we only raise errors on ops that claim to be functional - if an op schema admits aliasing or mutating but fails testing for the other, it may still decompose further and become functional.
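
For context, a rough sketch of how the underlying SchemaCheckMode (which the new pre-dispatch variant builds on) is used; the import path below is an assumption about where the mode lives in the tree:

```python
import torch
from torch._subclasses.schema_check_mode import SchemaCheckMode

x = torch.randn(3)

# SchemaCheckMode runs ops eagerly on concrete inputs and raises if the
# observed aliasing/mutation behavior disagrees with the op's schema.
with SchemaCheckMode():
    torch.ops.aten.sin.default(x)   # functional op: passes the check
```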

List of failed ops:
```
aten.atleast_1d.default
aten.atleast_2d.default
aten.atleast_3d.default
aten.cartesian_prod.default
aten.conj_physical.default
aten.alpha_dropout.default
aten.feature_dropout.default
aten.feature_alpha_dropout.default
aten.unsafe_chunk.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125481
Approved by: https://github.com/tugsbayasgalan
2024-05-14 21:07:21 +00:00
0302dc68bf [Reland] Fakify script object inputs and attributes for non-strict ex… (#125490)
A re-land of #124239.

This PR fakifies ScriptObject inputs and attributes in export non-strict mode by default.

The basic idea is to only fakify the script object during tracing (i.e. aot_export). After we get the traced graph module, eagerly executing, serializing, or running more passes will use the real script objects. This is essentially treating the script object as a constant tensor.

Concretely, we

1. fakify all the script object inputs and module attributes (gathered by constant_attrs).
2. patch the module's attributes with the fakified script objects.
3. right after aot_export, remove the patching (to avoid changing the original module), then modify the exported graph module's attributes to the real script objects.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125490
Approved by: https://github.com/angelayi
2024-05-04 02:39:42 +00:00
f1f142c44f Revert "Fakify script object inputs and attributes for non-strict export (#124239)"
This reverts commit ecc2e034f7e55bf9ff7f4e5df4e9086a5c92caaa.

Reverted https://github.com/pytorch/pytorch/pull/124239 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/124239#issuecomment-2089305447))
2024-05-01 23:56:00 +00:00
ecc2e034f7 Fakify script object inputs and attributes for non-strict export (#124239)
This PR fakifies ScriptObject inputs and attributes in export non-strict mode by default.

The basic idea is to `only fakify the script object during tracing (i.e. aot_export)`. After we get the traced graph module, eagerly executing, serializing, or running more passes will use the real script objects. This is essentially treating the script object as a constant tensor.

Concretely, we
1. fakify all the script object inputs and module attributes (gathered by constant_attrs).
2. patch the module's attributes with the fakified script objects.
3. right after aot_export, remove the patching (to avoid changing the original module), then modify the exported graph module's attributes to the real script objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124239
Approved by: https://github.com/zou3519
2024-04-30 15:57:25 +00:00
609c958281 Fix mypy issues in fake_tensor.py (#124428)
fake_tensor.py had mypy errors ignored. That seems less than desirable.

Also added SafePyObjectT<T> which is a tagged wrapper around a SafePyObject but provides static type checking (with no other guarantees).

Used `SafePyObjectT<TorchDispatchModeKey>` on some of the TorchDispatchModeTLS API to ensure that we don't accidentally inject a different type than expected into the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428
Approved by: https://github.com/malfet
2024-04-26 15:35:53 +00:00
35a82d4a4a Revert "Refresh OpOverloadPacket if a new OpOverload gets added (#124654)"
This reverts commit 872eeb0d7deebb58915289756d8c786f68630547.

Reverted https://github.com/pytorch/pytorch/pull/124654 on behalf of https://github.com/jeanschmidt due to Broken lots of internal signals, check D56571345 for more details ([comment](https://github.com/pytorch/pytorch/pull/124654#issuecomment-2078940680))
2024-04-26 08:56:03 +00:00
f131c2c199 Revert "Fix mypy issues in fake_tensor.py (#124428)"
This reverts commit 25c0d3f3f0b19b7ca88bc92e9dc56e391d18e010.

Reverted https://github.com/pytorch/pytorch/pull/124428 on behalf of https://github.com/jeanschmidt due to Unfortunately, I needed to revert #123735 and this one depends on it. So please check if there are no merge conflicts or breakages and feel free to merge this PR again ([comment](https://github.com/pytorch/pytorch/pull/124428#issuecomment-2078699836))
2024-04-26 06:15:17 +00:00
25c0d3f3f0 Fix mypy issues in fake_tensor.py (#124428)
fake_tensor.py had mypy errors ignored. That seems less than desirable.

Also added SafePyObjectT<T> which is a tagged wrapper around a SafePyObject but provides static type checking (with no other guarantees).

Used `SafePyObjectT<TorchDispatchModeKey>` on some of the TorchDispatchModeTLS API to ensure that we don't accidentally inject a different type than expected into the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428
Approved by: https://github.com/malfet
2024-04-25 14:07:53 +00:00
872eeb0d7d Refresh OpOverloadPacket if a new OpOverload gets added (#124654)
If a user accesses an OpOverloadPacket, then creates a new OpOverload,
then uses the OpOverloadPacket, the new OpOverload never gets hit. This
is because OpOverloadPacket caches OpOverloads when it is constructed.

This PR fixes the problem by "refreshing" the OpOverloadPacket if a new
OpOverload gets constructed and the OpOverloadPacket exists.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124654
Approved by: https://github.com/albanD
2024-04-24 19:30:52 +00:00
293f756cdc Support aot_export torchbind op (#123370)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123370
Approved by: https://github.com/zou3519
ghstack dependencies: #123367
2024-04-19 17:17:27 +00:00
e62169a8fa Support torchbind op dispatch in python (#123367)
We override the `__call__` method and register fake, functional, and proxy default dispatch mode implementations in its python_key_mode_table.

The idea is:
1. when the inputs contain a FakeScriptObject, we dispatch through the _get_dispatch mechanism. We implement the dispatch mode keys automatically in the operator's constructor.
2. when the inputs are not fakified, we dispatch through the original C++ dispatcher.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123367
Approved by: https://github.com/zou3519
2024-04-19 17:17:27 +00:00
648c39c47d Add OpOverload.redispatch; use it in new custom ops API (#124089)
A kernel has "dispatcher convention" if there is an additional keyset
arg at the beginning of the argument list. This PR:
- adds a way to register kernels with dispatcher_convention using
  Library.impl (pass dispatcher_convention = True)
- adds OpOverload.redispatch

We use both of the above in the new custom ops API: we register the
autograd kernel in dispatcher convention so that we can actually call
redispatch the way PyTorch built-in ops do.
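
A rough sketch of what a kernel in dispatcher convention looks like; the `mylib::bar` op is made up, the `dispatcher_convention` flag name is taken from this description, and the keyset-masking helper is an assumption rather than a verified API:

```python
import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("bar(Tensor x) -> Tensor")
lib.impl("bar", lambda x: x.clone(), "CompositeExplicitAutograd")

# Dispatcher convention: the kernel takes the currently-active DispatchKeySet
# as an extra leading argument, which is what makes redispatch possible.
def bar_autograd(keyset, x):
    # ...autograd setup would go here, then hand control back to the
    # dispatcher for the keys below Autograd.
    keyset = keyset & torch._C._after_autograd_keyset  # assumed helper
    return torch.ops.mylib.bar.default.redispatch(keyset, x)

lib.impl("bar", bar_autograd, "Autograd", dispatcher_convention=True)
```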

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124089
Approved by: https://github.com/albanD
ghstack dependencies: #123937, #124064, #124065, #124066, #124071
2024-04-18 12:48:04 +00:00
645173a0b5 Add torch.library.register_autograd (#124071)
Allows registering autograd for all custom op entry points (a minimal sketch follows this list):
- the new-style custom op API (custom_op)
- the old-style torch.library APIs
- C++ operator registration
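
A minimal sketch using the new-style custom op API (first bullet); the `mylib::mysin` op and its kernels are made up for illustration:

```python
import torch

@torch.library.custom_op("mylib::mysin", mutates_args=())
def mysin(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x)

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad):
    (x,) = ctx.saved_tensors
    return grad * torch.cos(x)

# Registers autograd for the op regardless of how the op was defined.
torch.library.register_autograd("mylib::mysin", backward, setup_context=setup_context)

x = torch.randn(3, requires_grad=True)
torch.ops.mylib.mysin(x).sum().backward()
print(x.grad)   # cos(x)
```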

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124071
Approved by: https://github.com/albanD
ghstack dependencies: #123937, #124064, #124065, #124066
2024-04-18 12:47:59 +00:00
27daa110c8 Back out "Refresh OpOverloadPacket if a new OpOverload gets added (#123578)" (#124324)
Summary:
Original commit changeset: 528276bc8a92

Original Phabricator Diff: D56057952

Differential Revision: D56271240

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124324
Approved by: https://github.com/davidberard98
2024-04-18 03:33:54 +00:00
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.
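
For illustration, the kind of change this covers (hypothetical function, before and after):

```python
# Before: unnecessary parentheses on an exception raised with no arguments.
def check_positive(value: int) -> None:
    if value < 0:
        raise ValueError()

# After: the parentheses are dropped.
def check_positive_fixed(value: int) -> None:
    if value < 0:
        raise ValueError
```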

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
1b4419dc4d Refresh OpOverloadPacket if a new OpOverload gets added (#123578)
If a user accesses an OpOverloadPacket, then creates a new OpOverload,
then uses the OpOverloadPacket, the new OpOverload never gets hit. This
is because OpOverloadPacket caches OpOverloads when it is constructed.

This PR fixes the problem by "refreshing" the OpOverloadPacket if a new
OpOverload gets constructed and the OpOverloadPacket exists.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123578
Approved by: https://github.com/albanD
ghstack dependencies: #123453
2024-04-11 13:18:06 +00:00
8a0436014d Support map in pre-dispatch functionalization (#121444)
When we enter map_autograd, we try to trace through the fwd/bwd of a map operator that is wrapped in a ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the Python part). As a result, it revealed a previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off the PreDispatch key.) This PR fixes that. Also, I shuffled some APIs around so that there is less code duplication, as the setting/unsetting logic is quite hard to get right.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444
Approved by: https://github.com/bdhirsh
2024-04-03 17:14:41 +00:00
25ad90adc0 Revert "Support map in pre-dispatch functionalization (#121444)"
This reverts commit 9288b274611abc904a67d9cb02c837aa2cb769fd.

Reverted https://github.com/pytorch/pytorch/pull/121444 on behalf of https://github.com/atalman due to New test test_aot_export_predispatch_map_1 is failing on windows ([comment](https://github.com/pytorch/pytorch/pull/121444#issuecomment-2034526949))
2024-04-03 12:55:23 +00:00
9288b27461 Support map in pre-dispatch functionalization (#121444)
When we enter map_autograd, we try to trace through the fwd/bwd of a map operator that is wrapped in a ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the Python part). As a result, it revealed a previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off the PreDispatch key.) This PR fixes that. Also, I shuffled some APIs around so that there is less code duplication, as the setting/unsetting logic is quite hard to get right.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444
Approved by: https://github.com/bdhirsh
2024-04-03 03:28:14 +00:00
6b8205d3de Revert "Support map in pre-dispatch functionalization (#121444)"
This reverts commit 079feea3379c021a330dbfac7668a5fc8fccc3bd.

Reverted https://github.com/pytorch/pytorch/pull/121444 on behalf of https://github.com/clee2000 due to sorry windows failure seems related 079feea337 https://github.com/pytorch/pytorch/actions/runs/8474191301/job/23220791555. PR got force merged before windows job finished ([comment](https://github.com/pytorch/pytorch/pull/121444#issuecomment-2026323614))
2024-03-28 23:42:26 +00:00
079feea337 Support map in pre-dispatch functionalization (#121444)
When we enter map_autograd, we try to trace through the fwd/bwd of a map operator that is wrapped in a ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the Python part). As a result, it revealed a previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off the PreDispatch key.) This PR fixes that. Also, I shuffled some APIs around so that there is less code duplication, as the setting/unsetting logic is quite hard to get right.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444
Approved by: https://github.com/bdhirsh
2024-03-28 21:56:36 +00:00
443e241cc5 Don't cache predispatch kernels (#121712)
Summary: Title

Test Plan: CI

Differential Revision: D54791087

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121712
Approved by: https://github.com/ydwu4
2024-03-12 18:05:59 +00:00
3ef0befdc9 Better error messages for impl_abstract_pystub (#120959)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120959
Approved by: https://github.com/drisspg
2024-03-04 15:24:36 +00:00
c646030cd2 Support higher order op functionalization in predispatch IR (#115314)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115314
Approved by: https://github.com/bdhirsh
2024-03-01 09:13:47 +00:00
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
2de3474711 Simplify kwargs propagation in __call__. (#117880)
In case no keyword arguments are passed, `**kwargs` would expand just fine without the extra overhead of `or {}`. In addition to reducing boilerplate, this also comes with a small perf improvement:
```
In [1]: def null(*args, **kwargs):
   ...:     pass
   ...:

In [2]: def call1(*args, **kwargs):
   ...:     return null(*args, **(kwargs or {}))
   ...:

In [3]: def call2(*args, **kwargs):
   ...:     return null(*args, **kwargs)
   ...:

In [4]: %timeit call1()
145 ns ± 2.07 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit call2()
118 ns ± 2.14 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [6]: %timeit call1()
147 ns ± 6.19 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [7]: %timeit call2()
117 ns ± 0.846 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117880
Approved by: https://github.com/Skylion007
2024-01-20 19:29:35 +00:00
76b1d44d57 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-25 04:51:21 +00:00
0567f71ac6 Revert " pre_dispatch aot_export (#115188)"
This reverts commit a267d6735051a4714fa2ac1c163315b650118744.

Reverted https://github.com/pytorch/pytorch/pull/115188 on behalf of https://github.com/jeanschmidt due to sadly, it is required to revert this commit in order to revert https://github.com/pytorch/pytorch/pull/115454 ([comment](https://github.com/pytorch/pytorch/pull/115188#issuecomment-1866310014))
2023-12-21 14:03:18 +00:00
a267d67350 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-20 21:36:25 +00:00
d85314c95c Support Predispatch functionalization (#113728)
In this PR, we are implementing Functionalization on the pre-dispatch graph. Today, every dispatch key except for DispatchKey.Python has a dedicated mode stack in Python. PreDispatch tracing relies on this behaviour by pushing ProxyTorchDispatchMode onto the DispatchKey.PreDispatch mode stack and handling the dispatching logic in Python. To make pre-dispatch functionalization work, we now need to push FunctionalTensorMode onto the DispatchKey.PreDispatch mode stack and make sure it runs before ProxyTorchDispatchMode (this is very similar to how post-dispatch tracing works). Here are some design decisions we made for this flow to work:

1. FunctionalTensorMode internally calls the C++ functionalize key. Since C++ functionalization goes after PreDispatch, if we are not careful, we will keep re-entering the PreDispatch key. We solve this by dispatching directly to the C++ Functionalize key.

2. We delete the mode_stack_per_key logic because the only realistic time it is exercised is for PreDispatch, and it is in general not safe to have a plain list because the ordering of FunctionalTensorMode and ProxyTorchDispatchMode matters and is hard to enforce on a plain list. Instead, we now have a private class that tracks the PreDispatch mode stack.

3. We will still run CompositeImplicitAutograd decomps in this PR, and will disable this logic later as a follow-up.

Some missing bits after this PR:
1. Preserving autograd ops in a functional form. Right now they still show up in the graph but in a "non-functional" way.
2. Turn off CompositeImplicitAutograd decomps
3. Functionalizing HOO

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113728
Approved by: https://github.com/bdhirsh
2023-12-19 20:28:35 +00:00
c5dcb50c00 [easy] aten ops: support passing all args as kwargs, including self (#114920)
Summary:
This is important for writing aten IR based graph transformations.

```
In [4]: [x.name for x in torch.ops.aten.reshape.default._schema.arguments]
Out[4]: ['self', 'shape']

In [8]: torch.ops.aten.reshape.default(torch.rand(1,2), shape=[2])
Out[8]: tensor([0.7584, 0.4834])

# === CANNOT CALL `self` BY KWARGS ===

In [7]: torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 1
----> 1 torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])

TypeError: OpOverload.__call__() got multiple values for argument 'self'

```

# Where's the problem?

1. the aten op's first arg is usually named `self` (aten/src/ATen/native/native_functions.yaml)
2. Unfortunately, in `torch._ops.{OpOverload, OpOverloadPacket}.__call__()`, the first arg is (by python convention) named `self` too.

So when `self` is passed by kwargs, `OpOverloadPacket.__call__` receives:

```
OpOverloadPacket.__call__(self, {"self": ...})
```

Python does not allow the same argument to appear twice, and hence

> TypeError: OpOverload.__call__() got multiple values for argument 'self'

# How to fix?

**Note that**, in the above, `self` is an instance of `OpOverloadPacket`, and the "self" kwarg is the input tensor to the aten op. To fix this, we only need to differentiate the two `self`s.

In Python, the first arg of a method does not need to be named `self`. So we change the `__call__` definition to:

```
def __call__(_self, ...):
```

Now the call becomes:

```
OpOverloadPacket.__call__(_self, {"self": ...})
```

where:
* `_self` is the instance to the `OpOverloadPacket`
* `"self"` is the input tensor to the aten op.

Test Plan:
```
In [4]: [x.name for x in torch.ops.aten.reshape.default._schema.arguments]
Out[4]: ['self', 'shape']

In [3]: torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])
Out[3]: tensor([0.5127, 0.3051])
```

Differential Revision: D51731996

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114920
Approved by: https://github.com/houseroad
2023-12-16 18:32:58 +00:00
cfa4370c07 torch.compile should auto-functionalize certain mutable ops (#114955)
Users may wish to torch.compile custom ops that mutate their inputs
and return nothing (this is a common class of operators).
torch.compile will automatically support such an op without anyone needing
to provide a functionalization kernel for it. Here's how.

Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
op. First, when FakeTensor sees this op, it can just return None.
This is the case because custom ops are not allowed to mutate input
metadata, so the FakeTensor rule for one that returns nothing is trivial.

Next, when Python FunctionalTensor sees the op, it will functionalize
it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
HOP and replacing the mutated inputs with the outputs of this HOP.
This HOP effectively runs the functional version of the op when
called: it clones inputs that will be mutated, runs the op, and
then returns Tensors with the new values.

In the future we can teach Inductor how to do re-inplacing when it sees
this HOP (like how triton kernels do it) but this isn't urgent (and is
more of a performance problem).
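
A minimal sketch of the class of ops described above, written with the later torch.library.custom_op decorator (the `mylib::sin_` name is made up, and the decorator itself postdates this commit):

```python
import torch

@torch.library.custom_op("mylib::sin_", mutates_args={"x"})
def sin_(x: torch.Tensor) -> None:
    x.sin_()

@torch.compile(fullgraph=True)
def f(x):
    torch.ops.mylib.sin_(x)   # mutates x in place, returns nothing
    return x + 1

print(f(torch.randn(3)))
```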

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114955
Approved by: https://github.com/bdhirsh
2023-12-05 14:53:08 +00:00
55064a4ef9 [BE] add parentheses to kwargs unpacking func(*args, **(kwargs or {})) (#115026)
This PR adds parentheses to kwargs unpacking `func(*args, **(kwargs or {}))` for better code readability.

With and without the parentheses are semantically equivalent because they produce the same bytecode.

```console
$ echo "func(*args, **kwargs or {})" | python3 -m dis -
  0           0 RESUME                   0

  1           2 PUSH_NULL
              4 LOAD_NAME                0 (func)
              6 LOAD_NAME                1 (args)
              8 BUILD_MAP                0
             10 LOAD_NAME                2 (kwargs)
             12 JUMP_IF_TRUE_OR_POP      1 (to 16)
             14 BUILD_MAP                0
        >>   16 DICT_MERGE               1
             18 CALL_FUNCTION_EX         1
             20 POP_TOP
             22 LOAD_CONST               0 (None)
             24 RETURN_VALUE

$ echo "func(*args, **(kwargs or {}))" | python3 -m dis -
  0           0 RESUME                   0

  1           2 PUSH_NULL
              4 LOAD_NAME                0 (func)
              6 LOAD_NAME                1 (args)
              8 BUILD_MAP                0
             10 LOAD_NAME                2 (kwargs)
             12 JUMP_IF_TRUE_OR_POP      1 (to 16)
             14 BUILD_MAP                0
        >>   16 DICT_MERGE               1
             18 CALL_FUNCTION_EX         1
             20 POP_TOP
             22 LOAD_CONST               0 (None)
             24 RETURN_VALUE
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115026
Approved by: https://github.com/Skylion007
2023-12-03 20:03:26 +00:00