Commit Graph

110 Commits

Author SHA1 Message Date
086dec3235 Pyrefly suppressions 6/n (#164877)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

INFO 0 errors (5,064 ignored)

Only four directories left to enable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877
Approved by: https://github.com/oulgen
2025-10-08 02:30:57 +00:00
4ab847bbc7 Pyrefly suppressions 4/n (#164615)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
2025-10-06 16:14:36 +00:00
a43c4c3972 [5/N] Apply ruff UP035 rule (#164423)
Continued code migration to enable ruff `UP035`. Most changes are about moving `Callable` from `typing` to `from collections.abc`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164423
Approved by: https://github.com/ezyang
2025-10-02 07:31:11 +00:00
8239ba4087 Fix various bugs in subclass input in export (#163770)
This adds basic support for subclass inputs in export (specifically for non-strict). I had to make fakify little more complicated which risks further divergence from dynamo fakification. But dynamo one is so complex, so i feel it is better to do this way. Also improved fake mode detection logic to recursively look into subclass inner tensors.

Differential Revision: [D83156489](https://our.internmc.facebook.com/intern/diff/D83156489)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163770
Approved by: https://github.com/avikchaudhuri
2025-09-28 18:03:32 +00:00
3c45af079a kill allow_complex_guards_as_runtime_asserts (#161794)
Summary:
[reland]
Since `allow_complex_guards_as_runtime_asserts` is now sync'd with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was a export-only concept).

Test Plan:
updated tests

Rollback Plan:

Differential Revision: D81334984

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161794
Approved by: https://github.com/zhxchen17
2025-09-04 00:17:01 +00:00
47742081c9 Revert "kill allow_complex_guards_as_runtime_asserts (#160198)"
This reverts commit 69d91b94ba5366f4444d8cb8fd3dab4de4f04d3d.

Reverted https://github.com/pytorch/pytorch/pull/160198 on behalf of https://github.com/jeffdaily due to let's revert again instead of waiting for forward fix, see earlier comments ([comment](https://github.com/pytorch/pytorch/pull/160198#issuecomment-3235165462))
2025-08-28 22:50:37 +00:00
69d91b94ba kill allow_complex_guards_as_runtime_asserts (#160198)
Summary: Since `allow_complex_guards_as_runtime_asserts` is now sync'd with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was a export-only concept).

Test Plan:
updated tests

Rollback Plan:

Differential Revision: D79903317

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160198
Approved by: https://github.com/ezyang
2025-08-28 19:36:19 +00:00
a8270dd124 Revert "kill allow_complex_guards_as_runtime_asserts (#160198)"
This reverts commit 196232bb935cb346f143d5c39e9a73c44121a033.

Reverted https://github.com/pytorch/pytorch/pull/160198 on behalf of https://github.com/atalman due to dynamo/test_activation_checkpointing.py::ActivationCheckpointingViaTagsTestsCUDA::test_compile_selective_checkpoint_triton_kernel_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/17289619543/job/49074475338) [HUD commit link](196232bb93) ([comment](https://github.com/pytorch/pytorch/pull/160198#issuecomment-3234013520))
2025-08-28 15:40:37 +00:00
196232bb93 kill allow_complex_guards_as_runtime_asserts (#160198)
Summary: Since `allow_complex_guards_as_runtime_asserts` is now sync'd with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was a export-only concept).

Test Plan:
updated tests

Rollback Plan:

Differential Revision: D79903317

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160198
Approved by: https://github.com/ezyang
2025-08-28 07:59:29 +00:00
12c0cf3fab switch prefer_deferred_runtime_asserts_over_guards in export (#160111)
Summary:
In preparation for checking shape guards in export, this PR effectively switches `prefer_deferred_runtime_asserts_over_guards` to `False`, matching Dynamo.

Actually that's a lie: we switch it to `allow_complex_guards_as_runtime_asserts`, which is `False` by default but can be controlled via an internally API to be `True`. This makes the two flags synchronized, so we should be able to kill `allow_complex_guards_as_runtime_asserts` at this point.

Test Plan:
updated tests

Rollback Plan:

Differential Revision: D79734206

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160111
Approved by: https://github.com/tugsbayasgalan
2025-08-27 22:51:10 +00:00
2b1ae29960 [Dynamo][Better Engineering] Add typing annotations to guard and source (#158397) (#159491)
Summary:
X-link: https://github.com/pytorch/executorch/pull/12986

As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo

This PR adds strict typing support to a critical set of files for dynamo, `source.py` and the base `_guards.py`

Running
```
mypy torch/_dynamo/source.py torch/_guards.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  1227 | 2208 | 55.57% | 207 | 362 | 57.18% |
| This PR | 2217 | 2217 | 100.00% | 362 | 362 | 100.00% |
| Delta    | +990 | +9 | +44.43% | +155 | 0 | +42.82% |

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Rollback Plan:

Reviewed By: JacobSzwejbka, yangw-dev

Differential Revision: D79199389

Pulled By: Lucaskabela

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159491
Approved by: https://github.com/anijain2305, https://github.com/yangw-dev
2025-07-30 22:57:50 +00:00
d987a6f7f0 Revert "[Dynamo][Better Engineering] Add typing annotations to guard and source (#158397)"
This reverts commit abcb24f4de11f8fedf2c2c9ff53b6092ef42306d.

Reverted https://github.com/pytorch/pytorch/pull/158397 on behalf of https://github.com/yangw-dev due to Suggested to fix failing internal signals on D78911890 ([comment](https://github.com/pytorch/pytorch/pull/158397#issuecomment-3133823766))
2025-07-29 19:49:40 +00:00
abcb24f4de [Dynamo][Better Engineering] Add typing annotations to guard and source (#158397)
As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo

This PR adds strict typing support to a critical set of files for dynamo, `source.py` and the base `_guards.py`

Running
```
mypy torch/_dynamo/source.py torch/_guards.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  1227 | 2208 | 55.57% | 207 | 362 | 57.18% |
| This PR | 2217 | 2217 | 100.00% | 362 | 362 | 100.00% |
| Delta    | +990 | +9 | +44.43% | +155 | 0 | +42.82% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158397
Approved by: https://github.com/anijain2305
2025-07-24 15:55:18 +00:00
9222552572 [non-strict export] uncovered cases of select and slice (#157821)
Summary:
`None` and `Ellipsis` in multi-dimensional indexing was previously not covered.

Moreover, we introduce a small optimization for `slice(None)` and a passthrough when symints do not appear in the indexing.

The remaining case is where indexing is by tensor, which is fairly complicated; we passthrough in that case.

Test Plan:
added tests

Rollback Plan:

Differential Revision: D77943929

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157821
Approved by: https://github.com/pianpwk
2025-07-10 05:48:12 +00:00
5dfd8a9c7a Remove is_jit_trace option (#157387)
Summary: Title

Test Plan:
CI

Rollback Plan:

Differential Revision: D77319249

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157387
Approved by: https://github.com/pianpwk
2025-07-03 09:20:27 +00:00
162ca185ff [BE][PYFMT] migrate PYFMT for torch/_[a-h]*/ to ruff format (#144551)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144551
Approved by: https://github.com/ezyang
ghstack dependencies: #148186
2025-06-25 06:16:06 +00:00
463fe36532 fix error message on specialization with Dim.DYNAMIC (#155738)
Previously specialization error messages would render sources that were pretty far from source-code names. E.g., given args named `x, y, zs`, the source for `y.size()[0]` would be rendered as `args[0][1].size()[0]`.

This is because we created artificial local names following `(args, kwargs)` structure instead of reusing signatures. This PR fixes that situation.

Basically we map prefixes of key paths that correspond to original arg names to root sources corresponding to those names; the rest of the key paths hang from these root sources.

Differential Revision: [D76461391](https://our.internmc.facebook.com/intern/diff/D76461391/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155738
Approved by: https://github.com/bobrenjc93
2025-06-13 10:33:46 +00:00
3027051590 [export] avoid float/bool specialization for scalar tensor construction (#154661)
Fixes #153411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154661
Approved by: https://github.com/angelayi
2025-05-30 17:18:21 +00:00
9c06dff1ce [multigraph] use specializations in compile_and_call_fx_graph (#153449)
The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM which does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs.

There's really two parts of this work:

**The frontend changes:**
1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we will lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153449
Approved by: https://github.com/zou3519
ghstack dependencies: #153433
2025-05-30 03:19:49 +00:00
8ac82c3e72 [export] support functools.partial forward (non-strict) (#153408)
Fixes #153086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153408
Approved by: https://github.com/tugsbayasgalan
2025-05-13 23:30:13 +00:00
47972f9092 [export] warn when Dim.AUTO 0/1 specializes (#151827)
Fixes #151582

example warning for Dim.AUTO:
```
torch/_export/non_strict_utils.py:499] dimension inputs['x'].shape[1] 0/1 specialized; Dim.AUTO was specified along with a sample input with hint = 1.
```

example error when Dim.DYNAMIC specializes:
```
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 0 for dimension inputs['x'].shape[0].
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151827
Approved by: https://github.com/angelayi
2025-05-01 06:00:51 +00:00
6a1b820255 [export] Enable symint inputs for AdditionalInputs and ShapesCollection (#151842)
With `AdditionalInputs`, the behavior is the same as with tensors:
```python
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

additional_inputs = torch.export.AdditionalInputs()
additional_inputs.add((5, 5))
additional_inputs.add((3, 5))
additional_inputs.add((5, 4))
ep = torch.export.export(
    M(), (6, 7), dynamic_shapes=additional_inputs, strict=False
)
```

With `ShapesCollection`, we now need to wrap integer inputs as `_IntWrapper` so that we can have a unique identifier for each integer input.
```python
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

from torch.export.dynamic_shapes import _IntWrapper

args = (_IntWrapper(5), _IntWrapper(5))
# Or we can do `args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), orig_args)`
shapes_collection = torch.export.ShapesCollection()
shapes_collection[args[0]] = Dim.DYNAMIC
shapes_collection[args[1]] = Dim.DYNAMIC
ep = torch.export.export(
    M(), args, dynamic_shapes=shapes_collection, strict=False
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151842
Approved by: https://github.com/pianpwk
2025-04-22 22:29:18 +00:00
2c27597d6a Infra for handling builtin ops (min, max, math.pow) (#151348)
Reapply of https://github.com/pytorch/pytorch/pull/150003

Differential Revision: [D73050801](https://our.internmc.facebook.com/intern/diff/D73050801/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151348
Approved by: https://github.com/zhxchen17
ghstack dependencies: #151347
2025-04-22 17:20:09 +00:00
e6969c1bd8 [export] Symint support (nonstrict, Dim.DYNAMIC) (#150198)
Fixes https://github.com/pytorch/pytorch/issues/113682 only in the non-strict export case. Also we only support Dim.DYNAMIC/AUTO, not named-Dims

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150198
Approved by: https://github.com/pianpwk
2025-04-10 15:06:23 +00:00
f8b53f4a75 [export] raise when Dim.DYNAMIC 0/1 specializes (#150716)
Previously we didn't catch this, mark_dynamic() just doesn't allocate a symbol for it

Differential Revision: D72486930

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150716
Approved by: https://github.com/angelayi
2025-04-07 18:58:42 +00:00
1017927c83 multidimensional slicing (#150104)
Differential Revision: D71962884

Fixes #150057

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150104
Approved by: https://github.com/angelayi
2025-04-02 20:57:16 +00:00
925fd4aa2e [export] min/max ranges for dim hints (#149590)
Differential Revision: D71522032

Adds min/max ranges to Dim.AUTO/DYNAMIC/STATIC, so users can do `Dim.AUTO(min=2, max=2048)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149590
Approved by: https://github.com/tugsbayasgalan
2025-03-31 21:32:20 +00:00
21bcbbfb5e fix range constraints for expr (#150103)
During tracing it is possible for a `s1: VR[2, inf]` to be replaced by a `s0: VR[3, inf]` (note smaller range) by the shape env. But after export, unfortunately we'd previously record `range_constraints[s0] = VR[2, inf]` (note larger range), which is incorrect.

This is because we'd map `s1.node.expr` (`s0`) to the `var_to_range` of `s1.node._expr` (`s1`) when creating `range_constraints`. The comment surrounding this code suggests this predated `bound_sympy`, but now we can do better.

For users, this means that when using `Dim.DYNAMIC` previously they wouldn't get input constraints checked sufficiently, now they do (shifting errors early).

Differential Revision: D71962694

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150103
Approved by: https://github.com/zhxchen17
2025-03-27 22:11:39 +00:00
46dd226702 Fakify torchbind objects in compile_fx and add tests for SigridTransformsInstanceTorchBind (#149529)
Summary:
We need to properly fakify torchbind objects, including the ones in graph module attributes, so the resgitered fake implementation works properly.

- _fakify_script_objects in `compile_fx`
- Allow fake torchbind objects in `torchbind_constants`

Remove `node.meta["unbacked_bindings"]` for `aot_compile` in `compile_fx`. Otherwise `ShapeProp` will fail when trying to resolve the `unbacked_bindings` of `with_effect` tokens.

Update `sigrid_transforms_test` to use the latest `torch._inductor.aot_compile` API.

Add a test for `Fakify torchbind objects in compile_fx and add tests for SigridTransformsInstanceTorchBind` in `e2e_test`.

Test Plan:
```
buck run //caffe2/torch/fb/sparsenn:sigrid_test -- -r test_transform_torch_bind

buck run //sigmoid/inference/test:e2e_test_cpu -- -r SigridTransforms

buck2 run mode/dev-nosan sigmoid/inference/ts_migration:pt2i_readiness_main -- --model_id 545017754 --test_suite ads_all --mode test_preproc

```

Differential Revision: D70013257

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149529
Approved by: https://github.com/angelayi
2025-03-21 18:58:28 +00:00
f47aa08130 [export] Support python assertion with symints. (#149444)
Summary: This diff ports some technique from torch.fx symbolic trace to trace through Python asserts when we run into data dependent symbolic shape assertions, so that we can achieve the same effect as torch dynamo to automatically turn assert into torch.check()s.

Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_python_asserts_with_sym_int
Differential Revision: D71425360

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149444
Approved by: https://github.com/tugsbayasgalan
2025-03-20 23:07:45 +00:00
9a184b1074 Monkeypatch fake mode so it errors on invalid custom ops (#149410)
Internal version: [D71294776](https://www.internalfb.com/diff/D71294776)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149410
Approved by: https://github.com/gmagogsfm
2025-03-20 04:50:57 +00:00
6a6de0e09d better error message (#147532)
Differential Revision: [D69939736](https://our.internmc.facebook.com/intern/diff/D69939736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147532
Approved by: https://github.com/avikchaudhuri, https://github.com/zou3519
2025-02-21 17:08:47 +00:00
77aa602871 [torchbind] Differentiate ScriptModule and ScriptObject with qualified name (#147399)
Summary:
This pr add a _is_script_object method to differentiate scriptModule and scriptObject, where the formal inherits from ScriptObject in C++ so they both passes the isinstance(obj, torch.ScriptObject) check.

The qualified name of ScriptObject (i.e. custom class) would starts with "__torch__.torch.classes", this has been a widely used assumption for dealing with custom class across our code base.

Test Plan: Add new test.

Differential Revision: D69685316

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147399
Approved by: https://github.com/yushangdi
2025-02-20 04:57:57 +00:00
24738768a8 more dist ops in non strict (#147417)
Summary: Previously we added support for `all_reduce` to non strict. This PR extends this support to other non-functional collectives that are remapped in Dynamo: `all_gather`, `all_gather_into_tensor`, `all_to_all_single`, `reduce_scatter_tensor`.

Test Plan: added unit tests

Differential Revision: D69813991

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147417
Approved by: https://github.com/angelayi
2025-02-19 21:29:16 +00:00
4ab967c44d all reduce non strict (#147133)
Summary:
Some distributed collectives like `all_reduce` have special handling in Dynamo, where they are mapped to functional collectives. Non-strict was previously blind to such mappings, which means using them would fail to trace. Here we show how intercepting them in non-strict's torch function mode can mimic this remapping logic. More ops to follow.

Side note: a recently added distributed test was in the wrong place, making the expected failures for non-strict not fire because we weren't actually generating those tests to begin with! Now fixed.

Test Plan: moved and updated test

Differential Revision: D69607140

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147133
Approved by: https://github.com/tugsbayasgalan
2025-02-15 19:37:08 +00:00
97d4d3c40a PEP585 update - torch/_export (#145138)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145138
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #145154
2025-01-19 18:48:35 +00:00
53256edff9 [export] Support module inputs for non strict mode. (#143925)
Summary:
Add experimental support for torch.nn.Module as input types.

Before this change, we don't support module inputs but recently we saw some interesting use cases like gpt-fast https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L68 where we directly pass in a module input for different variants of the same models.

Since we don't really care about non-param or non-buffer states in non strict mode, we don't care about those either and pretend they are like plain constants during tracing. We treat any module input like a nested container of tensor, and each time we will automatically register a pytree handler for these module types to flatten its state dict into a group of tensors. We will just inline any module method call during tracing like we did for `self` module in export_for_training. This will make input modules' behavior very similar to the training module in typical case, except that we don't record the inputs as parameter or buffers but rather just plain user inputs.

Test Plan: buck run mode/opt caffe2/test:test_export -- -r test_module_input

Differential Revision: D67680827

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143925
Approved by: https://github.com/tugsbayasgalan
2025-01-16 17:30:36 +00:00
d75ffccd0a Migrate from Tuple -> tuple in torch/_export (#144262)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144262
Approved by: https://github.com/avikchaudhuri
2025-01-06 22:20:26 +00:00
de484134e4 support slicing with symints in non-strict (#143217)
Differential Revision: [D67215745](https://our.internmc.facebook.com/intern/diff/D67215745/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143217
Approved by: https://github.com/tugsbayasgalan
2024-12-14 10:27:45 +00:00
51045e6251 make DimHints compatible with Dims (#138490)
Previously we'd been raising UserErrors when `Dim()` and DimHints (`Dim.AUTO/Dim.DYNAMIC`) were both specified in `dynamic_shapes`, this PR stops that, and uses `Dim()` objects to guide DimHints.

The key to this was making the `EqualityConstraint` class happy when it checks that inferred equivalence relations were specified in the original `dynamic_shapes` spec, and this introduces a `RelaxedConstraint` object to mark the hinted dimensions, so equality checks between `RelaxedConstraints` and other constraints are treated as valid.

Current behavior is that:
```
class Foo(torch.nn.Module):
    def forward(self, x, y):
        return x - y

inputs = (torch.randn(4, 4), torch.randn(4, 4))
shapes = {
    "x": (Dim.AUTO, Dim("d1", min=3)),
    "y": (Dim("d0", max=8), Dim.DYNAMIC),
}
ep = export(Foo(), inputs, dynamic_shapes=shapes)
```

The dimensions marked `AUTO` and `DYNAMIC` will have max & min ranges of 8 & 3 respectively. Note that inferred equality between `Dim()` objects & `Dim.STATIC` will still raise errors - `Dim()` suggests not specializing to a constant.

Differential Revision: D64636101

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138490
Approved by: https://github.com/avikchaudhuri
2024-10-22 07:43:48 +00:00
6dcd773c57 [export] clean up dynamic markers from tensors (#137230)
Summary:
When we handle dynamic shapes markers like `Dim.AUTO, Dim.DYNAMIC`, we use dynamo decorators, attaching set attributes to the export input tensors, e.g. `x._dynamo_dynamic_indices = set()`.

I thought this was fine, since it's done all the time with torch.compile, but it breaks some PT2Inference tests, specifically because unpickling a set attribute isn't possible with the C++ torch::jit::pickle_load call.

We've agreed that the PT2Inference side will clone sample inputs & pickle the original inputs to be safe, but this still establishes a nice invariant that user-facing decorators are both ignored & cleaned out in the lifecycle of an export call.

Test Plan: test_export

Differential Revision: D63773534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137230
Approved by: https://github.com/avikchaudhuri
2024-10-04 06:50:45 +00:00
cc2a66c55e [export] hook up mark_dynamic to export Dims (#137029)
Adds Dim.DYNAMIC which calls torch._dynamo.mark_dynamic() in the backend. Similar to Dim.AUTO in that it does automatic inference for ranges & relations, but errors out for specializations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137029
Approved by: https://github.com/avikchaudhuri
2024-10-01 17:05:09 +00:00
6bd9d37266 Remove allow-untyped-defs from torch.fx.experimental.symbolic_shapes (#137019)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137019
Approved by: https://github.com/Skylion007
ghstack dependencies: #136934, #136935, #136972
2024-10-01 13:22:10 +00:00
6075f566cc [export] simplify automatic dynamic shapes processing (#136591)
Removing `_transform_shapes_for_default_dynamic` and `assume_static_by_default=False` as added in https://github.com/pytorch/pytorch/pull/133620.

This reverts back to `assume_static_by_default=True` with the use of dynamo decorators (e.g. `maybe_mark_dynamic, mark_static`, instead) for handling Dim.AUTO & Dim.STATIC instead. This is easier to maintain, as it doesn't requiring reasoning about "inverting" the dynamic_shapes specs, and also opens up usage of other decorators (`mark_dynamic, mark_unbacked`).

On the user side this change has no effect, but internally this means dynamic behavior is determined only by the `dynamic_shapes` specs (ignoring user-side input decorators following https://github.com/pytorch/pytorch/pull/135536), but transferring this information for _DimHints via decorators, for Dynamo/non-strict to create symbolic_contexts accordingly, e.g. 7c6d543a5b/torch/_dynamo/variables/builder.py (L2646-L2666)

One caveat is we don't raise errors for dynamic decorators on the user side, since we don't know if they're from user markings, or from re-exporting with inputs we've previously marked.

Differential Revision: D63358628

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136591
Approved by: https://github.com/avikchaudhuri
2024-09-27 18:28:51 +00:00
5c5c33ac32 [Dynamo] Trace torch function modes entered outside of torch.compile (#133137)
This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call.
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.

Previously landed:
https://github.com/pytorch/pytorch/pull/133135
https://github.com/pytorch/pytorch/pull/133136
https://github.com/pytorch/pytorch/pull/133134
https://github.com/pytorch/pytorch/pull/133133
https://github.com/pytorch/pytorch/pull/133132
https://github.com/pytorch/pytorch/pull/133131
https://github.com/pytorch/pytorch/pull/133729
https://github.com/pytorch/pytorch/pull/133130

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #134732
2024-09-14 18:52:22 +00:00
8c8a3086a7 Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)"
This reverts commit 4528777e034b157a8329d1879daf52290eea199a.

Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2350937008))
2024-09-14 10:02:55 +00:00
4528777e03 [Dynamo] Trace torch function modes entered outside of torch.compile (#133137)
This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call.
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.

Previously landed:
https://github.com/pytorch/pytorch/pull/133135
https://github.com/pytorch/pytorch/pull/133136
https://github.com/pytorch/pytorch/pull/133134
https://github.com/pytorch/pytorch/pull/133133
https://github.com/pytorch/pytorch/pull/133132
https://github.com/pytorch/pytorch/pull/133131
https://github.com/pytorch/pytorch/pull/133729
https://github.com/pytorch/pytorch/pull/133130

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #134732
2024-09-14 02:40:43 +00:00
eb7dd91dd1 Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)"
This reverts commit fafdd588f27e1d56090c6d260d0382c255eaf9eb.

Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378))
2024-09-13 12:52:58 +00:00
fafdd588f2 [Dynamo] Trace torch function modes entered outside of torch.compile (#133137)
This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call.
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.

Previously landed:
https://github.com/pytorch/pytorch/pull/133135
https://github.com/pytorch/pytorch/pull/133136
https://github.com/pytorch/pytorch/pull/133134
https://github.com/pytorch/pytorch/pull/133133
https://github.com/pytorch/pytorch/pull/133132
https://github.com/pytorch/pytorch/pull/133131
https://github.com/pytorch/pytorch/pull/133729
https://github.com/pytorch/pytorch/pull/133130

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #134732
2024-09-13 08:41:00 +00:00
183c32fd3b Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)"
This reverts commit 0d15122092c27fec1143b800bab7c996d126b547.

Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/clee2000 due to something in this stack broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](444b52ff40), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/133137#issuecomment-2344054339))
2024-09-11 15:57:00 +00:00