Commit Graph

171 Commits

80cf0ce153 Enhance torch.vmap support from inside torch.compile (#116050)
This work rewrites vmap support in torch.compile by inlining most of
the frames into the existing FX graph. It also enables PyTorch to
support features that were previously missing, such as keyword args.
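
A hedged sketch of the kind of code this enables (the function here is illustrative, not from the PR):

```python
import torch

def pointwise(x, scale):
    return x * scale

@torch.compile(fullgraph=True)
def batched(xs, scale):
    # vmap is inlined into the FX graph rather than graph-breaking
    return torch.vmap(pointwise, in_dims=(0, None))(xs, scale)

batched(torch.randn(4, 3), torch.tensor(2.0))
```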

Fixes: https://github.com/pytorch/pytorch/issues/114306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116050
Approved by: https://github.com/zou3519
2024-01-22 17:53:45 +00:00
eba5d5485d [dynamo] make ConstantSource propagate through built-in ops for TensorVariable (#117704)
Fixes #117685.

This PR only preserves ConstantSource for built-in ops when we find that all the inputs are either constant tensors or Python constants.

It doesn't fundamentally solve the problem of preserving ConstantSource information through all operators that can potentially be constant folded.

For the following code in the issue:
```python
class Bob(torch.nn.Module):
    def __init__(self, p, val) -> None:
        super().__init__()
        self.p = p
        self.y = torch.nn.Parameter(torch.tensor(val))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This only looks dynamic but it's actually a constant value
        if get_y(self.y) < self.p:
            return torch.cat([x,x])
        else:
            return x
```
The exported graph looks like the following:
```python
class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: "f32[s0, s1]";

        arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        l_x_ = arg0

        # File: /home/yidi/local/pytorch/test/dynamo/test_export.py:1498 in forward, code: return torch.cat([x, x])
        cat = torch.cat([l_x_, l_x_]);  l_x_ = None
        return pytree.tree_unflatten([cat], self._out_spec)
```

Test Plan:
Added a new test for the given repro.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117704
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-01-18 20:18:34 +00:00
6e4e81a9ef [dynamo] Extend LazyVariableTracker to tuples (#117426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117426
Approved by: https://github.com/lezcano, https://github.com/jansel
2024-01-18 15:51:28 +00:00
4ba5318d3f [dynamo] Add DictView variable tracker (#108420)
This also starts a comparison pattern where we don't ask variables
what their type is, but what their capabilities are.
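
A minimal assumed example of what the DictView tracker covers:

```python
import torch

@torch.compile(fullgraph=True)
def f(x, d):
    # keys()/values() views are modeled by the new DictView tracker
    if "scale" in d.keys():
        x = x * d["scale"]
    return x + sum(d.values())

f(torch.randn(3), {"scale": 2, "bias": 1})
```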

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108420
Approved by: https://github.com/jansel
ghstack dependencies: #112252, #117630, #110524
2024-01-18 09:37:33 +00:00
f4df0f061c Implement set in terms of dict (#110524)
This allows us to heavily simplify the implementation of set, which was
"quite unique". Now we represent a set as a dict where all its values
are None.
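
The representation can be illustrated in plain Python (conceptual only, not dynamo internals):

```python
# A set modeled as a dict whose values are all None keeps insertion order
# and O(1) membership, which is all the set tracker needs.
s = {"a", "b"}
as_dict = dict.fromkeys(s)   # {'a': None, 'b': None}
assert set(as_dict) == s
assert "a" in as_dict        # membership is just a key lookup
```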

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110524
Approved by: https://github.com/jansel
ghstack dependencies: #112252, #117630
2024-01-18 09:36:41 +00:00
62496ffd0d [dynamo][easy]: Add support for operator.truth (#117463)
* This is an old builtin function equivalent to the bool constructor. It is easy enough to add support for.
* I also realized the tests were in the wrong class (the one reserved for testing default args), so I moved them.
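
A minimal assumed example of what now compiles:

```python
import operator
import torch

@torch.compile(fullgraph=True)
def f(x):
    # operator.truth(obj) is equivalent to bool(obj)
    return x + 1 if operator.truth([1]) else x - 1

f(torch.randn(2))
```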

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117463
Approved by: https://github.com/jansel
2024-01-14 19:08:31 +00:00
bf27dd6df9 Add dynamo support for operator.abs (#117442)
Adds a test case for operator.abs and allows constant folding with it. Partially addresses #116396.
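
For illustration, a minimal assumed example of the folded pattern:

```python
import operator
import torch

@torch.compile(fullgraph=True)
def f(x):
    # operator.abs on a Python constant can be folded at trace time
    return x * operator.abs(-3)

f(torch.randn(2))
```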

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117442
Approved by: https://github.com/jansel, https://github.com/malfet
2024-01-13 21:38:55 +00:00
1dd4813328 [BE][dynamo]: Add operator is and is not tests to dynamo tests (#116397)
Adds tests for an operator that was not unit tested in our test suite - improves coverage. Inspired by looking into https://github.com/pytorch/pytorch/pull/116397 after @XuehaiPan brought up some issues with builtins in #116389
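
The covered pattern, as a minimal assumed example:

```python
import operator
import torch

@torch.compile(fullgraph=True)
def f(x, flag):
    # `is` / `is not` lower to operator.is_ / operator.is_not
    return x + 1 if operator.is_(flag, None) else x - 1

f(torch.randn(2), None)
```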

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116397
Approved by: https://github.com/albanD, https://github.com/jansel
2024-01-09 21:13:22 +00:00
83e8a0721d Reland #111196 (take 4) "Support tensors as Dict keys" (#116934)
Fixes #ISSUE_NUMBER

See that PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116934
Approved by: https://github.com/ezyang, https://github.com/huydhn
2024-01-07 01:37:26 +00:00
2dca3e99eb Revert "Support tensors as Dict keys Re-PR of #111196 (#116785)"
This reverts commit 1badad9ce9694ef70f6a3dc01000f2cf310c4c11.

Reverted https://github.com/pytorch/pytorch/pull/116785 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/116785#issuecomment-1879592261))
2024-01-06 08:22:33 +00:00
1badad9ce9 Support tensors as Dict keys Re-PR of #111196 (#116785)
This prepares the PR where we implement sets in terms of dicts.
Rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class _Hashable.
The _Hashable class is opaque on purpose so that it fails hard
if it inadvertently leaks back into user code.
We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make it much clearer what needs to be modified to add
a new supported key type to Dicts.
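
A hedged sketch of what this enables (minimal assumed example):

```python
import torch

@torch.compile(fullgraph=True)
def f(x, key):
    d = {key: 2.0}     # a Tensor used directly as a dict key
    return x * d[key]

t = torch.randn(3)
f(t, t)
```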

Fixes [#107595](https://www.internalfb.com/tasks?t=107595)
Fixes [#111603](https://www.internalfb.com/tasks?t=111603)

Re-PR of https://github.com/pytorch/pytorch/pull/111196; sadly, due to reverts, we could not reuse @lezcano's original PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116785
Approved by: https://github.com/mlazos
2024-01-06 03:35:35 +00:00
0159e3abbd [dynamo] add a handler for itertools_chain_from_iterable and test (#116849)
1. Add a handler for itertools_chain_from_iterable.
2. Add a test for it.
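
A minimal assumed example of the now-handled call:

```python
import itertools
import torch

@torch.compile(fullgraph=True)
def f(x, lists):
    for v in itertools.chain.from_iterable(lists):
        x = x + v
    return x

f(torch.zeros(2), [[1, 2], [3]])
```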

Fixes #116463

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116849
Approved by: https://github.com/ezyang
2024-01-05 15:14:18 +00:00
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could; for a few I could not figure out the proper fix, so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
3149e4a667 [dynamo] fix sum() function with start argument (#116389)
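A minimal assumed example of the fixed pattern:

```python
import torch

@torch.compile(fullgraph=True)
def f(xs):
    # sum() with an explicit start value (here a tensor accumulator)
    return sum(xs, torch.zeros(2))

f([torch.ones(2), torch.ones(2)])
```
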
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116389
Approved by: https://github.com/Skylion007, https://github.com/malfet
2023-12-27 20:42:27 +00:00
e0e90bc0d4 Revert "[dynamo] fix sum() function with start argument (#116389)"
This reverts commit 3c9076f070fab5b27eae3b7846755c98b7c97a1a.

Reverted https://github.com/pytorch/pytorch/pull/116389 on behalf of https://github.com/kit1980 due to Breaks Meta-internal tests, but the issue could have been caught on GitHub ([comment](https://github.com/pytorch/pytorch/pull/116389#issuecomment-1870556927))
2023-12-27 19:05:55 +00:00
f657b2b1f8 [Dynamo][10/N] Remove TorchVariable and is_allowed (#116312)
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
  - The only one left is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It is called when users put the ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide a function's trace rule, the decorator would only be used on customer functions rather than torch functions. I'll defer this to a separate decorator refactor PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel
2023-12-27 18:47:05 +00:00
3b709d7c1e Revert "[Dynamo][10/N] Remove TorchVariable and is_allowed (#116312)"
This reverts commit 015bd0e0a189f929e469c6bc75fe1541c18a014d.

Reverted https://github.com/pytorch/pytorch/pull/116312 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/116312#issuecomment-1869825506))
2023-12-26 23:47:15 +00:00
3c9076f070 [dynamo] fix sum() function with start argument (#116389)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116389
Approved by: https://github.com/Skylion007
2023-12-26 06:37:55 +00:00
039fbeb016 [dynamo] fix functools.reduce() function with None as initial (#116398)
The `initial` argument in `functools.reduce` can be `None`.

```python
initial_missing = object()

def reduce(function, iterable, initial=initial_missing, /):
    it = iter(iterable)
    if initial is initial_missing:
        value = next(it)
    else:
        value = initial
    for element in it:
        value = function(value, element)
    return value
```

Reference:

- python/cpython#102759

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116398
Approved by: https://github.com/Skylion007
2023-12-25 21:23:28 +00:00
015bd0e0a1 [Dynamo][10/N] Remove TorchVariable and is_allowed (#116312)
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
  - The only one left is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It is called when users put the ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide a function's trace rule, the decorator would only be used on customer functions rather than torch functions. I'll defer this to a separate decorator refactor PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel
2023-12-23 09:44:09 +00:00
505a9e4854 add support for dynamic shapes in round (#115259)
Fixes #114310 and supersedes #114748.

There are two reasons why we have quite a few special cases for `round`:

1. `round` is actually two ops. With `ndigits=None` (default), `round` always returns an integer. When `ndigits` is an integer, the returned type is a float.
2. Although `round` takes two arguments, it is a unary function with a parameter rather than a binary one.
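
For reference, the type distinction described in point 1 can be checked in plain Python:

```python
# round() with ndigits=None returns an int; with an integer ndigits it
# returns a float (banker's rounding applies in both cases).
assert isinstance(round(2.5), int)        # -> 2
assert isinstance(round(2.5, 0), float)   # -> 2.0
```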

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115259
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2023-12-19 15:45:50 +00:00
b06b02559e Support non grapharg and intermediary grad access (#115898)
Adds support for something we need for both FSDP and optimizers. For sourced args that are not inputs (params, etc.), we use the dynamic_getattr flow on tensors. This soundly handles the storage, registration, and guarding downstream of tensor_wrap for the grad values. For non-sourced args (true intermediates), we only support None (the idea being that if we have a true intermediate in the graph with grad, we are already doing something weird).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115898
Approved by: https://github.com/bdhirsh
ghstack dependencies: #115315, #112184
2023-12-16 18:43:37 +00:00
fbeca60b1f Remove replace_all and make VTs mutable (#113725)
1. Removes calls to `replace_all` and `clone` and makes VTs mutable.
2. Properly handles Tuple Iterator mutation. Previously, TupleIterator variables would only be properly reconstructed if they were advanced at least once in a frame. On calls to `next`, the source information would be lost (due to constructing a new iterator without using the builder), which ensured that during codegen the variable would be reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to properly track the mutation and handle it by replaying calls to `next` at the end of the modified bytecode (see the sketch after the follow-up list below).
3. Added a test checking iadd side effects; this was missing from our unit test coverage.
4. Fixed two incorrect sources: DelayGraphBreakVariable and UserMethodVariable both relied on setting the source to AttrSource(parent, name) at the call site of `var_getattr`.
5. Fixed a bug in in-place add for lists: it would set the resulting VariableTracker's source to `None`, which would take a different reconstruct path in codegen. Now this is handled explicitly by reconstructing vars when `allow_cache=False`, so that during side-effect replay the mutated var is correctly updated.

In subsequent PRs:
* Refactoring side effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation update handling
* Remove clone and/or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Removing setting source in builder (it's set at the top level after a var is returned)
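
A hedged sketch of the user-visible behavior from item 2 above, assuming the `next` replay described there:

```python
import torch

@torch.compile
def f(it, x):
    return x + next(it)   # advancing `it` is recorded as a side effect

it = iter((1, 2, 3))
f(it, torch.zeros(1))
# The replayed mutation leaves the caller's iterator advanced by one.
assert next(it) == 2
```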

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
2023-12-10 09:31:21 +00:00
da341d0d48 [Dynamo][6.1/N] Refactor out TorchInGraphFunctionVariable and improve heuristic (#113432)
This is split from #113009; please check https://github.com/pytorch/pytorch/pull/113009#issuecomment-1804417925 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113432
Approved by: https://github.com/ezyang, https://github.com/jansel
2023-12-09 05:11:44 +00:00
e8e4141773 Revert "[Dynamo][6.1/N] Refactor out TorchInGraphFunctionVariable and improve heuristic (#113432)"
This reverts commit e61d6b42f0f4e4fa5bb816e03fb81e5bbcc9fa06.

Reverted https://github.com/pytorch/pytorch/pull/113432 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing dynamo tests in trunk e61d6b42f0, landrace? ([comment](https://github.com/pytorch/pytorch/pull/113432#issuecomment-1847787981))
2023-12-08 20:15:39 +00:00
e61d6b42f0 [Dynamo][6.1/N] Refactor out TorchInGraphFunctionVariable and improve heuristic (#113432)
This is split from #113009; please check https://github.com/pytorch/pytorch/pull/113009#issuecomment-1804417925 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113432
Approved by: https://github.com/ezyang, https://github.com/jansel
2023-12-08 17:15:14 +00:00
f4c67ffff4 [dynamo] Improve support for dynamic shapes str.format and _assert (#115203)
This removes a graph break in vision_maskrcnn.
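
A minimal assumed sketch of the pattern (a dynamic shape formatted into an assert message):

```python
import torch

@torch.compile(dynamic=True)
def f(x):
    # str.format on a symbolic size plus torch._assert no longer breaks
    torch._assert(x.shape[0] > 0, "expected a positive batch, got {}".format(x.shape[0]))
    return x * 2

f(torch.randn(4, 3))
```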

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115203
Approved by: https://github.com/yanboliang
2023-12-06 04:54:45 +00:00
4620170008 [Dynamo] Revert multiple PRs since they triggered compilation stuck internally (#115126)
Revert the following PRs to mitigate internal compilation stuck:
#113432
#114016
#114507
#114196
#114739
#114669

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115126
Approved by: https://github.com/xush6528
2023-12-05 22:35:37 +00:00
4b8ddbbc7e [dynamo] Improve graph break message for copy.deepcopy (#115120)
I was curious what hf_T5_generate was trying to deepcopy, so I updated the error message:
Before:
```
STATS graph_break
  ("'skip function deepcopy in file /home/jansel/conda/envs/pytorch/lib/python3.10/copy.py'', skipped according skipfiles.SKIP_DIRS'", 3)
  ...
```
After:
```
STATS graph_break
  ('copy.deepcopy UserDefinedObjectVariable(GenerationConfig)', 3)
  ...
```

Related issue: #115122

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115120
Approved by: https://github.com/oulgen
ghstack dependencies: #115095, #115046, #115057, #115119
2023-12-05 19:01:31 +00:00
522bae20df [dynamo] Support any() on SymNodeVariable (#115119)
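A minimal assumed example:

```python
import torch

@torch.compile(fullgraph=True, dynamic=True)
def f(x):
    # any() over symbolic shape comparisons (SymNodeVariables)
    if any([x.shape[0] == 1, x.shape[1] == 1]):
        return x + 1
    return x - 1

f(torch.randn(2, 3))
```
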
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115119
Approved by: https://github.com/yanboliang
ghstack dependencies: #115095, #115046, #115057
2023-12-05 19:01:31 +00:00
2e8ac5ea93 [dynamo] support dict.fromkeys() / OrderedDict.fromkeys() / defaultdict.fromkeys() (#115010)
Add support for `dict.fromkeys`, `OrderedDict.fromkeys`, and `defaultdict.fromkeys`.
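
A minimal assumed example:

```python
import torch

@torch.compile(fullgraph=True)
def f(x):
    d = dict.fromkeys(("a", "b"), 1.0)   # every key maps to the same value
    return x * d["a"] + d["b"]

f(torch.randn(3))
```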

Fixes #114963

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115010
Approved by: https://github.com/jansel
2023-12-04 01:49:59 +00:00
4cfe997490 [dynamo] handle setting .data on a tensor (#113080)
**Dynamo**

We don't want setattr in the graph. Setting data has interesting implications on both aliasing and on the autograd engine.

The safe recipe is:

1) Disable grad
2) Call set_()
3) Manually lower the version counter on the object to hide it from the autograd engine

This is effectively the same exact thing as setting .data, and it composes properly with aot_autograd and inductor.
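
A hedged sketch of that recipe in code; the version-counter call is an internal API whose name and signature here are an assumption from this era of the codebase:

```python
import torch

def set_data_(x: torch.Tensor, new: torch.Tensor) -> None:
    with torch.no_grad():   # 1) disable grad
        x.set_(new)         # 2) swap storage/metadata via set_()
    # 3) lower the version counter to hide the mutation from autograd
    #    (internal API; exact name/signature assumed)
    torch._C._autograd._unsafe_set_version_counter(x, 0)
```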

**aot_autograd**

For aot_autograd, there's another snag.

Specifically, when we invoke aot_autograd, we call `fake_mode.from_tensor()`, relying on memo to get the right tensor out. For .data mutations, this doesn't work, because the memoized fake_tensor is in the state it will be in at the end of the trace, not at the beginning. This means that the .data call is already applied, and the tensor shape (as in the case of these tests) mismatches. aot_autograd produces an invalid graph, with illegal calls like `torch.ops.aten.view.default(primals_2, [0])` where `primals_2` is actually sized `[6]` on input.

The new plan here is to:
1) Record tensor fakification policy in dynamo
2) Provide a fresh fake mode to all backends
3) Invoke from_tensor with the stored policy to get fresh new fake tensors in aot_autograd

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113080
Approved by: https://github.com/bdhirsh
2023-12-02 00:35:44 +00:00
172a103857 [dynamo] strict=True kwarg for zip (#114047)
Fixes https://github.com/pytorch/pytorch/issues/113894
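
A minimal assumed example (the strict= keyword requires Python 3.10+):

```python
import torch

@torch.compile(fullgraph=True)
def f(xs, ys):
    # zip(..., strict=True) raises ValueError on a length mismatch
    return [a + b for a, b in zip(xs, ys, strict=True)]

f([torch.ones(2), torch.ones(2)], [torch.zeros(2), torch.zeros(2)])
```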

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114047
Approved by: https://github.com/ezyang
2023-11-22 08:48:51 +00:00
9e657ce2ed [HigherOrderOp] set should_flatten_output=True for cond (#113819)
This PR sets should_flatten_output=True for cond. This effectively allows cond to support pytree outputs, with the output being flattened. Note: a single tensor output will be automatically cast to a tuple for torch.ops.higher_order.cond.

This PR also adds support for comparing BuiltinVariables (e.g. tuple); this is to make sure dynamo can inline the comparison of two tree_specs and verify that both branches return the same tree_spec.
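
A hedged sketch of the enabled pattern (the import path is assumed for this era of the codebase):

```python
import torch
from functorch.experimental.control_flow import cond  # path assumed

def true_fn(x):
    return {"out": x.sin()}   # pytree (dict) output

def false_fn(x):
    return {"out": x.cos()}

x = torch.randn(3)
res = cond(x.sum() > 0, true_fn, false_fn, (x,))
```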

Test Plan:
Existing tests. Will add more pytree tests and update the documentation in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113819
Approved by: https://github.com/zou3519
2023-11-22 04:06:30 +00:00
033d7b670a [Dynamo][6.1/N] Refactor out TorchInGraphFunctionVariable and improve heuristic (#113432)
This is splitted from #113009, please check https://github.com/pytorch/pytorch/pull/113009#issuecomment-1804417925 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113432
Approved by: https://github.com/ezyang
2023-11-17 23:42:00 +00:00
c94fdebd3e [dynamo] chore: Fallback on const_handler instead of special-casing on ConstantVariable (#113893)
Fixes https://github.com/pytorch/pytorch/pull/113874#issuecomment-1815269686

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113893
Approved by: https://github.com/ezyang
2023-11-17 07:46:58 +00:00
277229d0c6 [dynamo] Fix incorrectly casting SymNode to int when input is bool (#113871)
Fixes https://github.com/pytorch/pytorch/issues/113393, https://github.com/pytorch/pytorch/pull/113848#issuecomment-1814624510

Incorrectly casting the SymNode type causes it to take the wrong path in symbolic_shapes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113871
Approved by: https://github.com/jansel
2023-11-16 23:24:57 +00:00
5d170fce29 Revert "Support tensors as Dict keys (#111196)"
This reverts commit b0805fa5d0f73f3419129b1606a3e9a58eed2768.

Reverted https://github.com/pytorch/pytorch/pull/111196 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing internally. I will provide the details there ([comment](https://github.com/pytorch/pytorch/pull/111196#issuecomment-1813410149))
2023-11-15 23:08:00 +00:00
9146ca6a07 use sourceless builder for builtin getattr (#113340)
In TorchVision we use the following (simplified) dispatch mechanism:

```python
import torch

def kernel1(tensor):
    return tensor + 2

def dispatcher1(input):
    kernel = get_kernel(dispatcher1, type(input))
    return kernel(input)

def kernel2(tensor):
    return tensor - 2

def dispatcher2(input):
    kernel = get_kernel(dispatcher2, type(input))
    return kernel(input)

# We actually use the function and type as keys, rather than their names.
# However, this is currently not supported, but should be easy to add after
# https://github.com/pytorch/pytorch/pull/111196
REGISTRY = {
    "dispatcher1": {"Tensor": kernel1},
    "dispatcher2": {"Tensor": kernel2},
}

def get_kernel(dispatcher, input_type):
    dispatcher_registry = REGISTRY[dispatcher.__name__]
    for cls in input_type.__mro__:
        kernel = dispatcher_registry[cls.__name__]
        break
    return kernel
```

This can be compiled without graph breaks:

```python
cfn = torch.compile(dispatcher1, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 5)

cfn = torch.compile(dispatcher2, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 1)
```

However, if we start chaining these calls, we hit some issues:

```python
class Pipeline(torch.nn.Module):
    def forward(self, input):
        input = dispatcher1(input)
        input = dispatcher2(input)
        return input

cfn = torch.compile(Pipeline(), fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 3)
```

```
Can't access members of type(obj) for a generated custom object. Please use __class__ instead
```

The error message is not really helpful here. The following happens: when compiling `dispatcher1`, `get_kernel` gets inlined. That means when hitting `dispatcher2`, the `type` call no longer happens on an input with a source. Thus, in the first iteration we hit the top branch, while in the second we hit the bottom:

addb8e29cd/torch/_dynamo/variables/builtin.py (L1264-L1268)

And the error message I posted above originates from the type being treated as constant. This PR replaces this with a `SourcelessBuilder` instead.

With that fix in place, we hit another error, pointing to `input_type.__mro__`:

```
AssertionError: Consider SourcelessBuilder for ephemeral objects, usually objects created locally.
```

The fix is similar: instead of using a `VariableBuilder` here, we use a `SourcelessBuilder` in case we have no `source`:

addb8e29cd/torch/_dynamo/variables/builtin.py (L1167-L1168)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113340
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2023-11-15 13:01:20 +00:00
77f66ade66 Revert "use sourceless builder for builtin getattr (#113340)"
This reverts commit d64bc8f0f81bd9b514eb1a5ee6f5b03094e4e6e9.

Reverted https://github.com/pytorch/pytorch/pull/113340 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but the test is failing internally ([comment](https://github.com/pytorch/pytorch/pull/113340#issuecomment-1811684167))
2023-11-15 02:06:00 +00:00
b0805fa5d0 Support tensors as Dict keys (#111196)
This prepares the PR where we implement sets in terms of dicts.
Rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class `_Hashable`.
The `_Hashable` class is opaque on purpose so that it fails hard
if it inadvertently leaks back into user code.

We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make it much clearer what needs to be modified to add
a new supported key type to Dicts.

Fixes https://github.com/pytorch/pytorch/issues/107595
Fixes https://github.com/pytorch/pytorch/issues/111603
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111196
Approved by: https://github.com/jansel
2023-11-14 19:14:03 +00:00
d64bc8f0f8 use sourceless builder for builtin getattr (#113340)
In TorchVision we use the following (simplified) dispatch mechanism:

```python
import torch

def kernel1(tensor):
    return tensor + 2

def dispatcher1(input):
    kernel = get_kernel(dispatcher1, type(input))
    return kernel(input)

def kernel2(tensor):
    return tensor - 2

def dispatcher2(input):
    kernel = get_kernel(dispatcher2, type(input))
    return kernel(input)

# We actually use the function and type as keys, rather than their names.
# However, this is currently not supported, but should be easy to add after
# https://github.com/pytorch/pytorch/pull/111196
REGISTRY = {
    "dispatcher1": {"Tensor": kernel1},
    "dispatcher2": {"Tensor": kernel2},
}

def get_kernel(dispatcher, input_type):
    dispatcher_registry = REGISTRY[dispatcher.__name__]
    for cls in input_type.__mro__:
        kernel = dispatcher_registry[cls.__name__]
        break
    return kernel
```

This can be compiled without graph breaks:

```python
cfn = torch.compile(dispatcher1, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 5)

cfn = torch.compile(dispatcher2, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 1)
```

However, if we start chaining these calls, we hit some issues:

```python
class Pipeline(torch.nn.Module):
    def forward(self, input):
        input = dispatcher1(input)
        input = dispatcher2(input)
        return input

cfn = torch.compile(Pipeline(), fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 3)
```

```
Can't access members of type(obj) for a generated custom object. Please use __class__ instead
```

The error message is not really helpful here. The following happens: when compiling `dispatcher1`, `get_kernel` gets inlined. That means when hitting `dispatcher2`, the `type` call no longer happens on an input with a source. Thus, in the first iteration we hit the top branch, while in the second we hit the bottom:

addb8e29cd/torch/_dynamo/variables/builtin.py (L1264-L1268)

And the error message I posted above originates from the type being treated as constant. This PR replaces this with a `SourcelessBuilder` instead.

With that fix in place, we hit another error, pointing to `input_type.__mro__`:

```
AssertionError: Consider SourcelessBuilder for ephemeral objects, usually objects created locally.
```

The fix is similar: instead of using a `VariableBuilder` here, we use a `SourcelessBuilder` in case we have no `source`:

addb8e29cd/torch/_dynamo/variables/builtin.py (L1167-L1168)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113340
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2023-11-13 14:29:17 +00:00
0f7ac2635d Uniformly use SourcelessBuilder to handle user defined types (#113390)
Subsumes https://github.com/pytorch/pytorch/pull/110794

Fixes https://github.com/pytorch/pytorch/issues/110315

This is not really a 100% sound fix, a deeper analysis of the bug can be found at https://docs.google.com/document/d/1y-nRAPdbZEji52MPKYzC0U3VhvW9yEAEDqP5t5GhWZ0/edit

The general idea behind the fix here is that we are going to play fast and loose with user-defined classes: as Dynamo is written today, we are willing to pull out these types and directly manipulate them (e.g., look at their `__mro__`, etc.) without an intervening VariableTracker. As such, if I use `python_type` to extract the Python type of a VT, or if I am manually reading out the `__bases__` of a type that may be a user-defined class, and it is sourceless, all I need to do is use SourcelessBuilder instead of ConstantVariable to make sure I wrap it into the correct VT class.

The approach in https://github.com/pytorch/pytorch/pull/110794 was "more correct", but we'd have to go substantially further to get it all working. So I am doing this to unblock suo for now.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113390
Approved by: https://github.com/suo
2023-11-10 07:26:52 +00:00
cada6c7fee [dynamo] Fix a bug by desugaring in-place ops on constants (#113117)
Summary:

Python allows users to write code like
```
x = 1
x += y
x += z
```

This code has well-defined semantics: because x is an immutable primitive, the first `+=` will actually re-bind x; it is equivalent to `x = x + y`.

The second in-place operation will either similarly desugar (if the result of `x + y` is itself immutable), or possibly result in "true" in-place operation.

Now, this is a problem for us because today, dynamo tries to both resolve constant variables to their literal values at compile time and also compile in a way that treats `operator.*` builtin functions consistently. This leads to a bug where code like
```
x = 1
x += y
```
actually gets compiled to
```
1 += y
```
which is both semantically meaningless and a syntax error.

A very simple fix that we've already used to fix the special case of `+=` is to detect this, treat it as an edge case, and desugar eagerly into `x = x + y`.

The problem with that fix is that it only patched `iadd`, but actually *all* of the in-place operators exhibit this behavior.

This commit proposes that we tackle all of the in-place operators supported by fx in the same way: eagerly remap the operation to an assignment when the left-hand side is actually an immutable constant.
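
For concreteness, a hypothetical remapping table in the spirit of this commit (the names are illustrative, not the actual dynamo code):

```python
import operator

# Each fx-supported in-place operator falls back to its pure counterpart
# when the left-hand side is an immutable Python constant.
INPLACE_TO_PURE = {
    operator.iadd: operator.add,
    operator.isub: operator.sub,
    operator.imul: operator.mul,
    operator.itruediv: operator.truediv,
    operator.ipow: operator.pow,
}

x = 1
x = INPLACE_TO_PURE[operator.iadd](x, 2)   # what `x += 2` desugars to
assert x == 3
```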

**Alternatives?**

There might be some other fix possible that wouldn't produce a hardcoded remapping; I know that we generally don't like the growth of mappings and blocklists in dynamo.

I'm a little skeptical about a general solution though, because the bug is due precisely to Python's highly dynamic dispatching of in-place operations by type; since the fx graph has to be purely static, I suspect that we actually have to desugar this somewhere, because the dataflow is fundamentally different for true in-place operations on types that define `__iadd__`, etc. vs the desugaring on primitives.

I'm open to other suggestions

Test Plan:

I verified that the code in
https://github.com/pytorch/pytorch/issues/112656
compiles with this fix, and the compiled functions produce the same outputs as the originals.

This needs unit tests, but I'd like to get feedback on the approach in the meantime.

Fixes #112656

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113117
Approved by: https://github.com/yanboliang
2023-11-10 00:22:55 +00:00
e6eab49e11 [dynamo] graph break on setattr requires_grad (#113163)
Main: `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`
This PR: graph breaks; eager applies the mutation, and the new tensors are tracked

Fixes https://github.com/pytorch/pytorch/issues/109505 (the original bug does not occur, but a new bug where the mutation isn't applied - because AOTAutograd is not `requires_grad` mutation aware - is mitigated)
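
A minimal assumed example of the now-handled pattern:

```python
import torch

@torch.compile  # no fullgraph: the setattr below intentionally graph-breaks
def f(x):
    x.requires_grad = True   # graph break; eager applies the mutation
    return (x * 2).sum()

f(torch.randn(3))
```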

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113163
Approved by: https://github.com/bdhirsh
2023-11-09 13:13:29 +00:00
addb8e29cd Enable 2d + AC torch.compile (#112536)
This PR enables AC + torch.compile to work with FSDP + TP. The fix to
the higher-order op path is that we need to check both tensor and
tensor-subclass bases when making a sourceless builder.

NOTE: selective AC + 2D is still not working; that needs to be fixed
separately.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112536
Approved by: https://github.com/yf225
2023-11-09 06:12:13 +00:00
2c4be77f02 Revert "[dynamo] Graph break on setattr(Tensor, "data", Tensor) (#113043)" (#113297)
This reverts commit ddfe5725342b0c0f707222879ca9dac305f97210.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113297
Approved by: https://github.com/PaliC
2023-11-09 00:26:21 +00:00
94d95a91a2 Revert "[dynamo] graph break on setattr requires_grad (#113163)"
This reverts commit d261687d5f56ac8148fab2567cf1fa6dd5264def.

Reverted https://github.com/pytorch/pytorch/pull/113163 on behalf of https://github.com/PaliC due to relevant tests are not running for this pr, however, this is fixed after landing https://github.com/pytorch/pytorch/pull/113297/ ([comment](https://github.com/pytorch/pytorch/pull/113163#issuecomment-1802967236))
2023-11-09 00:23:04 +00:00
d261687d5f [dynamo] graph break on setattr requires_grad (#113163)
Main: `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`
This PR: graph breaks; eager applies the mutation, and the new tensors are tracked

Fixes https://github.com/pytorch/pytorch/issues/109505 (the original bug does not occur, but a new bug where the mutation isn't applied - because AOTAutograd is not `requires_grad` mutation aware - is mitigated)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113163
Approved by: https://github.com/bdhirsh
2023-11-08 19:51:23 +00:00
3e4d14702a On grad access, check if grad has changed and update stored example grad as needed (#112811)
Fixes https://github.com/pytorch/pytorch/issues/112446

This is a doozy of a PR, there's a few important things to keep in mind here:

1) We MUST lift all tensors accessed via attrs to inputs; getattr is a no-go in the graph, as it violates the aot_autograd contract. Furthermore, aot_autograd does not know how to apply in-place ops to intermediary tensors that are attributes (i.e., from getattr) anyway. Views from ops are fine.

2) `.grad` access handling in dynamo peeks at the underlying value, the real tensor, because re-piping FakeTensors already made with this fake_mode through builder anew is a no-go.

3) We have no proper mechanism for updating the hint / grapharg.example (the real value in (2) above) midway through a trace.

Therefore, what we need to do is reconcile the difference in grad stashed on grapharg.example. The easiest way to do this is lazily, upon .grad access, by reading the new value off the right fake tensors. We can then make a tensor using that data as a hint to VariableBuilder to make the right VariableTracker. Note that the example value used here (torch.zeros) in the PR, is a dummy value only used as a tracing hint, it does not leak out into real runtime code.

Alternatively, we could implement accumulate_grad_ in python...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112811
Approved by: https://github.com/jansel
2023-11-08 05:45:00 +00:00