This PR introduces a way to compile a region of an FX graph using `fx.traceback.annotate`.
### UX
1) In the user code, mark the region that you want compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we rely only on the string key `compile_with_inductor` and ignore the integer value. As needs arise, we can update the logic.
Example
```
def fn(x, y):
    sin = torch.sin(x)
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)
```
2) You have to instruct the compiler to act on the annotations via the `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that setting the annotation alone is enough, but for now, to control the blast radius, this explicit step is required. One such example is:
```
# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )
```
3) Fixable in the short term - you have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable; we just need to make CI happy. An end-to-end sketch follows below.
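As a sketch tying the three points above together (the backend helper is the one defined in (2); the exact wiring may differ from the final API):
```
import torch
import torch.fx.traceback as fx_traceback

def fn(x, y):
    sin = torch.sin(x)
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)

# Backend from (2): aot_autograd with fw/bw compilers set to the annotation-aware pass.
opt_fn = torch.compile(fn, backend=aot_eager_regional_inductor())
with torch.fx.traceback.preserve_node_meta():
    out = opt_fn(torch.randn(10), torch.randn(10))
```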
### Implementation
1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations and create subgraphs in the main graph.
2) Calls `torch._inductor.standalone_compile` on these subgraphs and jams the returned callable into the FX graph in place of the `call_module` node (a rough sketch of this flow follows).
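Under the assumption that the annotation lands in `node.meta["custom"]`, the flow looks roughly like this (helper names are simplified, not the real pass):
```
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

class AnnotatedNodes(OperatorSupportBase):
    # Assumption: fx_traceback.annotate stores its dict under node.meta["custom"]
    def is_node_supported(self, submodules, node):
        return "compile_with_inductor" in node.meta.get("custom", {})

def partition_annotated_regions(gm: torch.fx.GraphModule):
    # Step 1: scoop annotated nodes into call_module subgraphs.
    partitioner = CapabilityBasedPartitioner(
        gm, AnnotatedNodes(), allows_single_node_partition=True
    )
    fused = partitioner.partition_and_fuse()
    # Step 2 (in the real pass): run torch._inductor.standalone_compile on each
    # call_module subgraph and splice the returned callable back into the graph.
    return fused
```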
The resulting graph looks something like this (search for `torch__inductor_standalone_compile_inner`).
Forward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)
        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(sin, primals_2)
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
        getitem: "f32[10]" = inner[0]; inner = None
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
        return (sin_1, primals_1, primals_2, sin, getitem)
```
Backward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1); primals_1 = None
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        cos: "f32[10]" = torch.ops.aten.cos.default(add); add = None
        mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2); mul_1 = sin = primals_2 = None
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
        getitem: "f32[10]" = inner[0]
        getitem_1: "f32[10]" = inner[1]; inner = None
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1); getitem_1 = cos_1 = None
        return (mul_4, getitem)
```
### Some issues raised in the HOP meeting
1) CSE will not differentiate between nodes with different custom metadata and may do the wrong thing.
2) SAC - the recomputed forward will be smaller than the forward. Will we compile a smaller region then?
3) What happens if you have an op in the middle that does not disturb the topology - is it still one subgraph?
4) What happens with the nesting of `fx_traceback.annotate`? Are there any ordering requirements?
5) What are we going to use the annotations for?
a) compile flex
b) streams
c) nn.Module info to organize MoE components for pipelining
d) PP stages
e) Rename graph nodes for easier debugging
f) No nested regional compile
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165188
This is a follow-up of #165037. It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of finally enabling the related linter checks.
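A small example of the preferred pattern:
```
t = type(42)
assert t is int            # preferred: identity comparison for exact type checks
assert t is not float
# type(x) == int also works, but can be fooled by __eq__ overrides and is what the
# linter flags; for subclass-aware checks, isinstance(x, int) remains the idiomatic choice.
```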
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
Summary: Fix the edge case by allowing `call_function` nodes with no deps as graph entry (starter_nodes) in the splitter.
Test Plan:
The test shall pass in the current diff (after fix), and fail in the parent diff (before fix)
```
buck test mode/opt //glow/fb/fx/lowering:split_tests -- test_dataclass_as_graph_entry
```
Rollback Plan:
Differential Revision: D81232435
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161716
Approved by: https://github.com/ezyang
[relanding again after fixing internal build]
Summary:
This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous(), but we want to find those call sites and handle them properly by calling is_contiguous_or_false() instead of is_contiguous() explicitly where appropriate.
I had to fix one issue after removing the implicit size-oblivious reasoning; here is the context.
In https://github.com/pytorch/pytorch/pull/157472 we defined sym_is_contiguous as the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression representing contiguity and is guaranteed not to throw a DDE.
When people call is_contiguous, we do sym_is_contiguous().guard_bool().
When people call is_contiguous_or_false, we do sym_is_contiguous().guard_or_false().
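As a rough, self-contained illustration of those two guarding behaviors (the real logic lives in C++ on TensorImpl; the toy class below only mirrors the semantics):
```
class ToySymBool:
    def __init__(self, value):
        # value is True, False, or None (undecidable, e.g. depends on an unbacked size)
        self.value = value

    def guard_bool(self):
        # the is_contiguous() path: raises when the answer is undecidable
        if self.value is None:
            raise RuntimeError("data-dependent error: contiguity is undecidable")
        return self.value

    def guard_or_false(self):
        # the is_contiguous_or_false() path: falls back to False instead of raising
        return bool(self.value) if self.value is not None else False

print(ToySymBool(None).guard_or_false())  # False, no exception
```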
One path that was not handled well was this one:
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }
  return sym_is_contiguous_default(memory_format);
}
```
Namely, if we call sym_is_contiguous_custom and matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, we used to call is_contiguous(this, memory_format).
This went through load_pyobj_interpreter and ended up calling the Python is_contiguous, which used implicit size-oblivious reasoning.
Once we removed that implicit size-oblivious reasoning, the right thing is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get a DDE even when the caller is using sym_is_contiguous.
So I had to define it for the pyinterpreter, and then override it for nested tensors.
Approved by: https://github.com/ezyang
Test Plan:
contbuild & OSS CI, see e444cd24d4
Rollback Plan:
Differential Revision: D80435179
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160869
Approved by: https://github.com/ezyang
Summary: This PR introduces shape guards to export. Previously only value ranges, equalities, and specializations were tracked for symbolic expressions, and a forward hook checked them. Instead, we now generate a function that checks the shape guards and call it in the exported program.
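A hedged usage sketch of what this looks like from the outside (the exact error raised on a guard violation is not shown, since it depends on the generated check function):
```
import torch
from torch.export import export, Dim

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1

batch = Dim("batch", min=2, max=1024)
ep = export(M(), (torch.randn(8, 4),), dynamic_shapes={"x": {0: batch}})
print(ep.module()(torch.randn(16, 4)).shape)  # within the declared range, passes the shape guards
# ep.module()(torch.randn(1, 4)) would violate the min=2 guard and be rejected
```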
Test Plan:
updated several tests
Rollback Plan:
Differential Revision: D80713603
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161178
Approved by: https://github.com/tugsbayasgalan
Summary:
A tool to track events during graph split, specifically how nodes end up in acc or cpu subgraphs.
Usage: use env vars to specify a mode and the necessary arguments.
FX_NET_ACC_SPLITTER_TRACKER_MODE: Tracker mode.
```
Different modes of the event tracker:
"0": Tracker not enabled (by default)
"1": Tracker enabled but no dumps. Information available by setting breakpoints and visually inspecting in pdb.
"2": Tracker enabled and dumps all events to DUMP_PREFIX_all.txt
"3": In addition to the events dump, track nodes specified by ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES recursively and dump to DUMP_PREFIX_nodex.txt
"4": In addition to the events dump, track all nodes with more than 1 event recursively and dump to DUMP_PREFIX_nodex.txt
FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH: overriding dump path. Leave empty for `~`.
FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES: Nodes to track for mode "3".
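For illustration, the tracker could be enabled from Python before running the split; the variable names come from the description above, and the node names are hypothetical:
```
import os

os.environ["FX_NET_ACC_SPLITTER_TRACKER_MODE"] = "3"                # dump events + tracked nodes
os.environ["FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"] = "/tmp/split"  # override the default `~` prefix
os.environ["FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES"] = "add_1,linear_3"  # hypothetical node names
# ... then run the acc/cpu graph split as usual
```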
Test Plan: New unit test
Reviewed By: georgiaphillips
Differential Revision: D79203595
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159795
Approved by: https://github.com/ezyang
Summary: This change updates `getattr_recursive` to handle qualnames containing digit indices into a ModuleList, for example `op_instances.1.value_model.feature_weights`.
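A minimal sketch of the idea (not the actual implementation): digit path components index into ModuleList-style containers instead of going through a plain getattr:
```
def getattr_recursive_sketch(obj, qualname: str):
    for attr in qualname.split("."):
        # "op_instances.1.value_model.feature_weights" -> op_instances[1].value_model.feature_weights
        obj = obj[int(attr)] if attr.isdigit() else getattr(obj, attr)
    return obj
```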
Test Plan:
TBA
Rollback Plan:
Reviewed By: jiayisuse
Differential Revision: D80503985
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161204
Approved by: https://github.com/jiayisuse
This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous(), but we want to find those call sites and handle them properly by calling is_contiguous_or_false() instead of is_contiguous() explicitly where appropriate.
I had to fix one issue after removing the implicit size-oblivious reasoning; here is the context.
In https://github.com/pytorch/pytorch/pull/157472 we defined sym_is_contiguous as the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression representing contiguity and is guaranteed not to throw a DDE.
When people call is_contiguous, we do sym_is_contiguous().guard_bool().
When people call is_contiguous_or_false, we do sym_is_contiguous().guard_or_false().
One path that was not handled well was this one:
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }
  return sym_is_contiguous_default(memory_format);
}
```
Namely, if we call sym_is_contiguous_custom and matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, we used to call is_contiguous(this, memory_format).
This went through load_pyobj_interpreter and ended up calling the Python is_contiguous, which used implicit size-oblivious reasoning.
Once we removed that implicit size-oblivious reasoning, the right thing is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get a DDE even when the caller is using sym_is_contiguous.
So I had to define it for the pyinterpreter, and then override it for nested tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159197
Approved by: https://github.com/ezyang
Summary: as title. We've received requests from various parties interested in turning on provenance tracking by default. In this PR, we prepare to turn on, by default, the part of provenance tracking that doesn't have too much overhead.
- Change the `provenance_tracking` config to `provenance_tracking_level` (see the sketch after this list)
- Turn on the following provenance tracking by default when `basic_provenance_tracking` is True:
- `set_kernel_post_grad_provenance_tracing` for kernels; this adds a mapping between triton kernels and post_grad nodes
- `dump_inductor_provenance_info` if we're dumping the tlparse log
- `get_graph_provenance_json` and dump `create_mapping_pre_post_grad_nodes`. This creates a mapping between pre_grad and post_grad nodes. Since we're not turning on provenance tracking in GraphTransformObserver by default, the mapping here may be incomplete/limited.
- add stack trace from post grad nodes to inductor IR nodes
- add exception swallowing for all functions above
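A hedged sketch of the new knob (the accepted level values and their exact meanings are assumptions here, not confirmed by this summary):
```
from torch._inductor import config

# Replaces the old boolean `trace.provenance_tracking`; level 1 stands in for the
# "basic" tracking described above (kernel <-> post_grad mapping, tlparse dumps).
config.trace.provenance_tracking_level = 1
```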
Test Plan:
CI
Rollback Plan:
Differential Revision: D80031559
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
Summary:
Add support for torch._check() in TorchScript jit.script frontend.
* It will be special-cased to behave like torch._assert, i.e. turned into an if + raise of an exception (see the sketch below).
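A hedged sketch of what this enables (message arguments are omitted since their handling under scripting is not described here):
```
import torch

@torch.jit.script
def checked_relu(x: torch.Tensor) -> torch.Tensor:
    # Special-cased by the frontend: lowered to an `if` + raise, like torch._assert.
    torch._check(x.dim() == 2)
    return torch.relu(x)

print(checked_relu(torch.randn(3, 4)))
```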
Test Plan:
Unit tests
Rollback Plan:
Differential Revision: D79744604
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159988
Approved by: https://github.com/davidberard98
Automatically replaces split with rsplit when relevant, and only performs the split up to the first (or last) value. This lets the split return early and improves efficiency.
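An example of the pattern this change applies:
```
path = "torch.ops.aten.add.Tensor"
last = path.rsplit(".", 1)[-1]   # splits once, from the right, then stops
same = path.split(".")[-1]       # splits everywhere, then discards most of the pieces
assert last == same == "Tensor"
```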
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160107
Approved by: https://github.com/albanD
Summary:
The `replace_hook` is called once for each user of the replaced node. This fix avoids adding duplicated node sources.
This also means that if there are two nested passes like:
```
with GraphTransformObserver(gm, "outer"):
    with GraphTransformObserver(gm, "inner"):
        .....
```
We'll only see the outer pass's pass name recorded for the replaced node in the "from_node" node meta. I think this is fine. In practice, the outer pass usually contains a more meaningful name, e.g. `decompose_auto_functionalized`, and the inner pass name is just a default pass name like `pattern_matcher`.
Test Plan:
```
buck2 run @mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer_replace
```
Rollback Plan:
Differential Revision: D79203058
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159484
Approved by: https://github.com/angelayi
When select has a data-dependent input, we can't tell whether the actual index should be index + size or index. To avoid throwing a DDE, we allocate a new unbacked symbol to represent the storage offset of the output view and compute its value dynamically at runtime during inductor lowering.
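A hedged illustration of the case being handled (whether this compiles cleanly without a DDE is exactly what the change above addresses):
```
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x, idx):
    i = idx.item()           # unbacked SymInt: sign unknown at trace time
    return x.select(0, i)    # storage offset is i*stride or (i + size)*stride

print(f(torch.randn(4, 3), torch.tensor(-1)))
```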
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157605
Approved by: https://github.com/ColinPeppler
Summary:
For this particular instance, we're doing `from torch._inductor.config import trace` and then `trace.provenance_tracking...`, but for all other call sites we're doing `from torch._inductor import config` and then `config.trace.provenance_tracking...`.
Test Plan:
CI
Rollback Plan:
Differential Revision: D78699876
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158796
Approved by: https://github.com/c00w
Summary:
As inductor provenance tracking is getting more use cases, we want to separate the inductor provenance tracking guard flag from the general `trace.enabled`, so we can enable provenance tracking without all the overhead of `trace.enabled`.
- change the guard flag from `trace.enabled` to `trace.provenance_tracking`. It is turned on by either `TORCH_COMPILE_DEBUG=1` or `INDUCTOR_PROVENANCE=1`.
- Move the provenance tracking logic and variables out of DebugContext, because DebugContext is only enabled with `trace.enabled`. Since the variables are now global variables, added `reset_provenance_globals()` context manager to reset them for each `compile_fx()` call.
- Move `set_kernel_post_grad_provenance_tracing` from `util.py` to `debug.py` so now all provenance related logic is in `debug.py`.
In the future, if we want to enable it further, we can change the provenance tracking flag to be enabled when `TORCH_TRACE` is set. I think we should do that in a separate PR, so it's easier to revert if this flag change creates any problem.
See more motivation in internal Diff
Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer
buck run mode/dev-nosan fbcode//caffe2/test:fx -- -r graph_provenance
buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing
```
Differential Revision: D78287976
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158399
Approved by: https://github.com/angelayi
This is useful for vLLM, which runs AOTAutograd directly on graphs after
they have been split.
I created a new flag for this instead of reusing
`keep_original_node_name` (please let me know if you think I should reuse this).
The reasoning is:
- The names of the placeholder nodes are different from the targets of the placeholder nodes. The targets are the actual input names (see the illustration after this list).
- Backwards compatibility: this API has been out for ~4 years, it
looks public, and it has extensive public use. For example, this change
would actually be BC-breaking to vLLM (they rely on the subgraph input
names being different at the moment).
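For reference, the `.name` / `.target` distinction called out in the first bullet (in a plain trace they usually coincide; the point is that subgraph placeholders can legitimately diverge):
```
import torch
from torch import fx

def f(x, y):
    return x + y

gm = fx.symbolic_trace(f)
for n in gm.graph.nodes:
    if n.op == "placeholder":
        # .target carries the actual input name; .name is the (possibly uniquified) node name
        print(n.name, n.target)
```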
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157733
Approved by: https://github.com/ezyang
Summary: Add a pass `use_triton_fp8_swish_replace_normal_swish` to replace `_triton_swish_rms_norm` with its counterpart that supports fp8, `triton_swish_rms_norm`, and turn on fp8 during inference.
Test Plan:
```
buck2 run mode/opt mode/inplace -c fbcode.platform010_cuda_version=12.4 -c fbcode.nvcc_arch=h100 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark -- --lower-backend=AOT_INDUCTOR --model-snapshot-id=899072727_0 --node-replacement-dict="{}" --gpu-trace --add-passes=use_triton_fp8_swish_replace_normal_swish
```
The perf improvement on the 100x model with this pass is roughly 7%; details are recorded [here](https://docs.google.com/document/d/1eIV_OTQyQcf_DlEDxwycTwhyGxT5OJkLzs8cPL6EMYc/edit?tab=t.0)
Rollback Plan:
Reviewed By: frank-wei
Differential Revision: D76531303
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157574
Approved by: https://github.com/frank-wei
Summary:
Encountered a 'register_forward_pre_hook not supported on ScriptModule' error when trying to publish CFR MTML with the remote_ro module placed in remote. The issue may come from the fact that the local net from torchArrow is already a ScriptModule before the gen_app_graph pass.
Test Plan:
hg checkout 1ff14dfaade4ac1f3cbbf38fbd72f7fdd5cdcd16
bash hstu_blocker.sh
Rollback Plan:
Reviewed By: RenfeiChen-FB
Differential Revision: D77341370
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156904
Approved by: https://github.com/jingsh
Summary:
Cloned https://github.com/pytorch/pytorch/pull/153558 from benjaminglass1 and fixed internal typing errors.
Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.
Decisions made along the way:
1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it was unnecessarily generalized. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special cases in the function could be implemented as class variables (see the sketch after this list).
The remainder of this PR fixes up the bugs exposed by the updated typing, as well as various nitpicky typing issues.
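The second decision leans on a general Python property; a self-contained toy version (not the actual `torch.ops` code):
```
class _ToyNamespace:
    # Constants that used to be special-cased inside __getattr__ can be plain class
    # attributes, because __getattr__ only runs after normal attribute lookup fails.
    name = "aten"

    def __getattr__(self, op_name):
        op = f"<op {self.name}::{op_name}>"   # stand-in for real op resolution
        setattr(self, op_name, op)            # cache so future lookups skip __getattr__
        return op

ns = _ToyNamespace()
print(ns.name)   # class attribute: __getattr__ never runs
print(ns.add)    # resolved lazily via __getattr__, then cached
```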
Test Plan: CI
Differential Revision: D75497142
Co-authored-by: Benjamin Glass <bglass@quansight.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154555
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/zou3519, https://github.com/benjaminglass1