Fixes #165447
On AOTAutogradCache load, the serialization function we pick is just `lambda: self`, because the object itself is an AOTAutogradCacheEntry. However, this isn't safe, because `wrap_post_compile` will make `self` unserializable, since it needs to load triton kernels and other runtime state!
So instead, on AOTAutogradCache load, we preserve the bytes that were used to load the object in the first place, and return them on a call to `serialize()`. This effectively saves a copy of the pre-hydrated artifact, without doing an eager copy until someone actually calls `serialize`.
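The pattern, as a minimal sketch (the class and attribute names here are illustrative, not the actual cache internals):
```py
import pickle


class LoadedCacheEntry:
    """Illustrative sketch: keep the raw bytes used to deserialize the entry,
    so that serialize() can return them even after post-compile hydration makes
    the live object unpicklable."""

    def __init__(self, entry, pickled_bytes: bytes):
        self.entry = entry
        self._pickled_bytes = pickled_bytes

    @classmethod
    def load(cls, raw: bytes) -> "LoadedCacheEntry":
        # Deserialize eagerly, but remember the bytes we started from.
        return cls(pickle.loads(raw), raw)

    def serialize(self) -> bytes:
        # No eager copy up front; the original artifact bytes are simply
        # handed back when (and if) someone asks for them.
        return self._pickled_bytes
```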
Test Plan:
Run
```py
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 4)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(4, 8)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))


device = "cuda"
m = M().to(device)
sample_inputs = (torch.randn(2, 2, device=device),)
eager_out = m(*sample_inputs)

with torch._dynamo.config.patch("enable_aot_compile", True):
    compiled_fn_path = "./m.pt"
    compiled_fn = torch.compile(
        m,
        fullgraph=True
    ).forward.aot_compile((sample_inputs, {}))

    compiled_fn.save_compiled_function(compiled_fn_path)

    torch._dynamo.reset()

    with torch.compiler.set_stance("fail_on_recompile"):
        with open(compiled_fn_path, "rb") as f:
            loaded_fn = torch.compiler.load_compiled_function(f)
        assert loaded_fn is not None

        compiled_out = loaded_fn(m, *sample_inputs)
        assert torch.allclose(eager_out, compiled_out)
```
twice and see that it succeeds.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165474
Approved by: https://github.com/yiming0416, https://github.com/zhxchen17
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:
(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: https://github.com/pytorch/pytorch/pull/164939)
(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164577
Approved by: https://github.com/ezyang
ghstack dependencies: #165372
The match for backward nodes might be in a different submod, so we should check all submods for potential matches.
In flex attention, this could happen if `mask_mod` has operations (such as index) that increase the seq_nr of the forward graph nodes. Then the backward flex_attention nodes cannot find a match in their own subgraph.
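A rough sketch of the lookup change (hypothetical helper, only to illustrate searching every submodule rather than just the backward node's own subgraph):
```
def find_fwd_match_across_submods(bwd_node, graph_module):
    # Illustrative only: match a backward node to a forward node by seq_nr,
    # scanning all submodules instead of just the one containing bwd_node.
    target_seq_nr = bwd_node.meta.get("seq_nr")
    for _, submod in graph_module.named_modules():
        if not hasattr(submod, "graph"):
            continue
        for node in submod.graph.nodes:
            if node.meta.get("seq_nr") == target_seq_nr and "custom" in node.meta:
                return node
    return None
```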
```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```
Also tested on torchtitan joint_graph_runner branch. The flex_attention backward nodes are annotated now.
```
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" LOG_RANK=0 TRAIN_FILE="torchtitan.train" TORCHFT_LIGHTHOUSE="http://localhost:29510" PYTORCH_ALLOC_CONF="expandable_segments:True" torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint="localhost:0" --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml --model.name joint_graph_runner.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165202
Approved by: https://github.com/SherlockNoMad
This PR introduces a way to compile a region of FX graph using `fx.traceback.annotate`.
### UX
1) In the user code, mark the region that you want to be compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we just rely on the string `compile_with_inductor` and ignore the integer. As the needs arise, we can update the logic.
Example
```
def fn(x, y):
    sin = torch.sin(x)
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)
```
2) You have to instruct the compiler to use the annotations with the `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that just setting the annotation is enough, but for now, to control the blast radius, we need to do this explicitly. One such example is
```
# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )
```
3) Fixable in the short term - you have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable; we just need to make CI happy. See the sketch below.
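For now that looks roughly like this (a sketch; the exact wrapping depends on how the test drives compilation):
```
import torch
import torch.fx.traceback as fx_traceback

with fx_traceback.preserve_node_meta():
    opt_fn = torch.compile(fn, backend=aot_eager_regional_inductor())
    out = opt_fn(x, y)
```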
### Implementation
1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations and then create subgraphs in the main graph.
2) Calls `torch._inductor.standalone_compile` on these subgraphs and jams the returned callable into the FX graph in place of the `call_module` node.
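A sketch of that flow (the partitioner and `standalone_compile` are real APIs; the glue code and support class here are illustrative):
```
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport


class AnnotatedNodeSupport(OperatorSupport):
    # Illustrative: a node is "supported" iff it carries the annotation.
    def is_node_supported(self, submodules, node):
        return "compile_with_inductor" in node.meta.get("custom", {})


def scoop_annotated_regions(gm):
    # Group annotated nodes into fused submodules inside the main graph; each
    # submodule is then compiled with torch._inductor.standalone_compile and
    # the returned callable is spliced in at the call_module site.
    partitioner = CapabilityBasedPartitioner(gm, AnnotatedNodeSupport())
    return partitioner.partition_and_fuse()
```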
Resulting graph looks something like this - search for `torch__inductor_standalone_compile_inner`
Forward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)
        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(sin, primals_2)
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
        getitem: "f32[10]" = inner[0]; inner = None
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
        return (sin_1, primals_1, primals_2, sin, getitem)
```
Backward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1); primals_1 = None
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        cos: "f32[10]" = torch.ops.aten.cos.default(add); add = None
        mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2); mul_1 = sin = primals_2 = None
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
        getitem: "f32[10]" = inner[0]
        getitem_1: "f32[10]" = inner[1]; inner = None
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1); getitem_1 = cos_1 = None
        return (mul_4, getitem)
```
### Some issue raised in the HOP meeting
1) CSE will not differentiate between nodes with different custom meta and may do the wrong thing.
2) SAC - the recomputed forward will be smaller than the forward. Will we then compile a smaller region?
3) What happens if you have an op in the middle that does not disturb the topology - is it still one subgraph?
4) What happens with the nesting of `fx_traceback.annotate`? Are there any ordering requirements?
5) What are we going to use the annotations for?
a) compile flex
b) streams
c) nn.Module info to organize MoE components for pipelining
d) PP stages
e) Rename graph nodes for more debugging
f) No nested regional compile
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165188
Fixes #164814 - we update to include cases where we know the symbolic expression is statically one. There are two errors here: the first is in graph capture, where a tensor with size 0 yet a symbolic stride would attempt to keep the symbolic stride, resulting in a mismatch. The second is in inductor codegen, where the squeeze check only tested size == 1, missing the case where a symbolic stride equals 1.
Also fixes #164924 (credit to @bobrenjc93's fuzzer for finding an issue affecting users :) )
### Test plan:
```
python test/dynamo/test_aot_autograd.py AotAutogradFallbackTests
```
Results in:
```
..
----------------------------------------------------------------------
Ran 49 tests in 45.622s
OK (expected failures=1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164897
Approved by: https://github.com/laithsakka
Splits the training and inference paths for aot stage2 compile.
1. Split `aot_stage2_autograd` into `_aot_stage2a_partition`, `_aot_stage2b_fw_compile`, `_aot_stage2b_bw_compile`, and the rest.
2. Split `aot_stage2_inference` into `_aot_stage2b_inference_compile` and the rest.
I'm leaving these as functions with underscore names since the I/O interfaces and the exact boundaries of these splits are still somewhat up in the air.
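Schematically, the new structure looks something like this (signatures and the `aot_config` dict are simplified placeholders, since, as noted, the real interfaces are still in flux):
```
def _aot_stage2a_partition(joint_gm, aot_config):
    # Placeholder: split the joint graph into forward and backward graphs.
    fw_gm, bw_gm = aot_config["partition_fn"](joint_gm)
    return fw_gm, bw_gm

def _aot_stage2b_fw_compile(fw_gm, aot_config):
    # Placeholder: run the forward compiler (plus its pre/post wrappers).
    return aot_config["fw_compiler"](fw_gm)

def _aot_stage2b_bw_compile(bw_gm, aot_config):
    # Placeholder: run the backward compiler (plus its pre/post wrappers).
    return aot_config["bw_compiler"](bw_gm)

def aot_stage2_autograd(joint_gm, aot_config):
    fw_gm, bw_gm = _aot_stage2a_partition(joint_gm, aot_config)
    compiled_fw = _aot_stage2b_fw_compile(fw_gm, aot_config)
    compiled_bw = _aot_stage2b_bw_compile(bw_gm, aot_config)
    # "the rest": assemble the runtime autograd wrapper around the two
    return compiled_fw, compiled_bw
```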
Differential Revision: D84028203
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164808
Approved by: https://github.com/SherlockNoMad
This is a follow-up of #165037. It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of eventually enabling the related linter checks.
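A minimal example of the pattern being changed:
```
def is_exact_int(x):
    # preferred: identity comparison of type objects
    return type(x) is int

def is_exact_int_legacy(x):
    # discouraged form that the linter flags
    return type(x) == int
```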
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of eventually enabling the related linter checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
In aot_stage2_autograd:
Before calling fw_compiler, we run pre_compile for the following wrappers:
* FakifiedOutWrapper
* FunctionalizedRngRuntimeWrapper
After, we run post_compile for the following wrappers:
* EffectTokensWrapper
* AOTDispatchSubclassWrapper
* FunctionalizedRngRuntimeWrapper
* FakifiedOutWrapper
In aot_stage2_inference:
Before calling inference compiler, we run pre_compile for the following wrappers (same as above):
* FakifiedOutWrapper
* FunctionalizedRngRuntimeWrapper
After, we run post_compile for the following wrappers (different than above):
* FunctionalizedRngRuntimeWrapper
* FakifiedOutWrapper
* EffectTokensWrapper
* AOTDispatchSubclassWrapper
This PR makes both do the post_compiles in the same order.
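As a generic sketch of the wrapper protocol involved (hypothetical wrapper class; the real wrappers are the ones listed above), the point is simply that both paths now apply `post_compile` in one fixed order:
```
class NoopWrapper:
    # Illustrative stand-in for the AOTAutograd runtime wrappers listed above.
    def pre_compile(self, gm):
        return gm

    def post_compile(self, compiled_fn):
        return compiled_fn


def compile_with_wrappers(gm, compiler, wrappers):
    for w in wrappers:
        gm = w.pre_compile(gm)
    compiled_fn = compiler(gm)
    for w in wrappers:  # same fixed order on both the autograd and inference paths
        compiled_fn = w.post_compile(compiled_fn)
    return compiled_fn
```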
Differential Revision: D84213657
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165016
Approved by: https://github.com/zhxchen17, https://github.com/bdhirsh
# Propagate custom meta data to backward
Support propagating the user annotation tags to the backward graph by extending the `copy_fwd_metadata_to_bw_nodes` util (recommended by @xmfan, thanks!).
Example annotation API (added in https://github.com/pytorch/pytorch/pull/163673):
```
class M(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fdsp_bucket": 0}):
                x = x + 1
            x = x - 2
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                x = x * 2
        x = x / 3
        return x
```
Assumptions (some inherited from https://github.com/pytorch/pytorch/pull/126573):
- I am trusting the seq_nr mapping introduced to aot_autograd nodes in https://github.com/pytorch/pytorch/pull/103129
- I am also trusting that the forward is single threaded, since seq_nr is thread local. If this isn't always true, we'll need to also plumb thread_id through the same machinery which is populating seq_nr.
- **(This is changed in this PR!) I assume all backward graph nodes have "is_backward" for 'partitioner_tag', and all other nodes are forward graph nodes**. If we don't run export before `aot_export_joint_with_descriptors`, then none of the nodes have "nn_module_stack" in node meta. If we do run export first, then we don't need this change.
- I copy "custom" node meta from forward to backward graph nodes.
Question:
- Is it a good idea to copy all "custom" node meta? Or should we create a dedicated key in custom node meta to be copied? @SherlockNoMad
- Do we expect people to run export before using `aot_export_joint_with_descriptors`?
- Can we assume the following for graphs produced by `aot_export_joint_with_descriptors`: "all backward graph nodes have "is_backward" for 'partitioner_tag', and all other nodes are forward graph nodes"? Maybe this is a question for @ezyang
```
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_
python test/export/test_export.py -k preserve_anno
python test/distributed/tensor/test_dtensor_export.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164174
Approved by: https://github.com/xmfan, https://github.com/SherlockNoMad
Summary: Minor refactor where we push some args in the aot joint with descriptors workflow that are not used in export stage to the compile stage where they are actually used.
Test Plan: existing tests should pass
Differential Revision: D83850316
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164584
Approved by: https://github.com/tugsbayasgalan
Changelog:
1. When we run into an operation we didn't proxy, we end up emitting fake constants. We error under a config, and we disable that config for some internal users. The reason we want to error is that this signals a coverage problem we need to address, but at the same time we don't want to be disruptive to already-working flows.
2. The previous attribute-mutation detection logic in non-strict didn't account for nested module structure. This fixes a silent incorrectness issue when exporting esm and qwen in non-strict, and some torchbench models like levit_128 and demucs.
3. The previous logic also didn't work in cases where we mutate a container attribute, since the old approach pytree-flattened the old and new attributes, resulting in a length mismatch. We now handle this gracefully.
Differential Revision: [D83673054](https://our.internmc.facebook.com/intern/diff/D83673054)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164372
Approved by: https://github.com/avikchaudhuri
This partially solves https://github.com/pytorch/pytorch/issues/163641. We do not need to ban unbacked-to-unbacked replacements if all RHS symbols are inputs, since we know those symbols are seen by the whole program.
This issue was found while tracing some vLLM models with unbacked symbols, namely Qwen/Qwen2-1.5B-Instruct; allowing those replacements makes the reasoning logic easier.
For the similar data-dependent pattern, I am thinking of creating a set of replacements that we apply only during static evaluation, instead of none, to make reasoning better.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163652
Approved by: https://github.com/bobrenjc93
Summary:
LSTM was not exportable with non-strict export as it failed at `_detect_attribute_assignment`
This is because the `_flat_weights` attribute in LSTM is a list of registered parameters and will be updated by the `_update_flat_weights` method in `forward`.
However, in `_detect_attribute_assignment`, we manually restore the state of the module by `mod.__dict__.update(snapshot)`. Therefore, it should be fine to turn the `ValueError` into a warning so that RNN models are exportable with non-strict export.
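The relevant pattern, roughly (a sketch of the snapshot/restore idea, not the actual export code):
```
import warnings
from contextlib import contextmanager

@contextmanager
def _attribute_assignment_guard(mod):
    # Sketch: snapshot the module's attributes, run the wrapped forward,
    # restore the snapshot, and warn (rather than raise) if anything was
    # reassigned.
    snapshot = dict(mod.__dict__)
    try:
        yield
    finally:
        changed = [k for k in mod.__dict__ if mod.__dict__.get(k) is not snapshot.get(k)]
        mod.__dict__.update(snapshot)
        if changed:
            warnings.warn(f"Attributes reassigned during export: {changed}")
```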
Added test to verify that there is no lifted tensor constant and no fake tensor leakage.
Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_rnn_variants_with_warning
Differential Revision: D83196971
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163809
Approved by: https://github.com/tugsbayasgalan
Summary: We observed a case where the fwd graph has duplicated return nodes, which leads to errors due to fx renaming the node; thus we add poi info into the node name.
Test Plan:
### unit test
```
CUDA_VISIBLE_DEVICES=3 buck2 test mode/opt -m ovr_config//triton:beta -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //caffe2/test/functorch:test_aotdispatch -- test_quantize_activation_duplicate_nodes
```
Buck UI: https://www.internalfb.com/buck2/de5eccc6-4064-4214-843d-70b8e3829afe
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4503599937670844
Network: Up: 217KiB Down: 72KiB (reSessionID-73e5c269-4f4d-4a54-896a-79c077eea326)
Executing actions. Remaining 0/2 0.1s exec time total
Command: test. Finished 1 local
Time elapsed: 45.9s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0
### E2E
before
f798417700
after
Differential Revision: D82844100
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163364
Approved by: https://github.com/Yuzhen11
Summary:
This diff does a big refactor of PrecompileContext to make it considerably simpler: instead of being a CacheArtifactManager and managing a bunch of bytes, it simply stores two things: dynamo cache entries and backend cache entries. When asked, it stitches them together into PrecompileCacheEntries, which are stored by DynamoCache.
This structure then allows us to register DynamoCache to the regular Megacache API, instead of having two separate APIs that are confusing. It also lets us remove the autotune cache integration, since MegaCache API will automatically store autotune cache entries.
The intent here is that users who want to use caching precompile will simply be able to use `torch.compiler.save_cache_artifacts` as before, just with `torch._dynamo.config.caching_precompile` set to True. They can also directly interact with PrecompileContext if they wish to specifically only load precompile entries, using `PrecompileContext.create_cache_entries()`.
Saving single entries and such with DynamoCache still works normally.
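Conceptually, the simplified context looks something like this (illustrative dataclass, not the real types):
```
from dataclasses import dataclass, field

@dataclass
class PrecompileContextSketch:
    # Illustrative only: just two stores, keyed so that dynamo and backend
    # entries for the same compilation can be stitched back together.
    dynamo_entries: dict = field(default_factory=dict)
    backend_entries: dict = field(default_factory=dict)

    def create_cache_entries(self):
        # Stitch the two stores into per-key precompile cache entries on demand.
        return {
            key: (dynamo_entry, self.backend_entries.get(key))
            for key, dynamo_entry in self.dynamo_entries.items()
        }
```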
Test Plan:
All existing unit tests pass.
Rollback Plan:
Differential Revision: D82380307
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162886
Approved by: https://github.com/zhxchen17
This PR refactors AOTAutograd slightly:
- It adds `simple_wraps` to various wrappers so that the reference to inner functions is stored in the output of AOTAutograd.
- It attaches a `serialize()` method to the result of `aot_stage2` in the event of an eager backward compile.
I discussed the lazy backward case with @bdhirsh, and we agreed that serialization in that case would probably use a different, more AOT API anyway, so we do not implement a serialize function for the lazy backward case. AOT precompile, at least initially, will always eagerly compile the backward.
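The serialization hook, as a rough sketch (using `functools.wraps` in place of the PR's `simple_wraps` helper; names are illustrative):
```
import functools
import pickle

def attach_serialize(inner_fn, cache_entry):
    @functools.wraps(inner_fn)  # keeps a reference to the inner function on the wrapper
    def runtime_wrapper(*args, **kwargs):
        return inner_fn(*args, **kwargs)

    # Only attached when the backward was compiled eagerly.
    runtime_wrapper.serialize = lambda: pickle.dumps(cache_entry)
    return runtime_wrapper
```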
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162527
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162171