This PR introduces a way to compile a region of an FX graph using `fx.traceback.annotate`.

### UX

1) In the user code, mark the region that you want compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we just rely on the string `compile_with_inductor` and ignore the integer. As needs arise, we can update the logic. Example:

```
def fn(x, y):
    sin = torch.sin(x)

    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1

    return torch.sin(add)
```

2) You have to instruct the compiler to use the annotations via the `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that setting the annotation alone is enough, but for now, to control the blast radius, we need to do this explicitly. One such example:

```
# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )
```

3) Fixable in the short term - you have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable, we just need to make CI happy. An end-to-end sketch of how these three pieces fit together follows this list.
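For illustration, here is a minimal end-to-end sketch of how the three steps above could be wired together. The import locations for `aot_autograd` and `compile_fx_annotated_nodes_with_inductor` are assumptions (this PR does not pin them down here), so treat this as a sketch rather than the exact API:

```
import torch
import torch.fx.traceback as fx_traceback

from torch._dynamo.backends.common import aot_autograd
# Assumed import path for the transformation introduced in this PR.
from torch.fx.passes.regional_inductor import compile_fx_annotated_nodes_with_inductor


def aot_eager_regional_inductor():
    # Step 2: route both fw and bw graphs through the annotation-aware compiler.
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )


def fn(x, y):
    sin = torch.sin(x)
    # Step 1: only this region is compiled with inductor.
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)


x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)

# Step 3: annotations only reach the compiler if node meta is preserved.
with torch.fx.traceback.preserve_node_meta():
    compiled_fn = torch.compile(fn, backend=aot_eager_regional_inductor())
    out = compiled_fn(x, y)
```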
### Implementation

1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations and create subgraphs in the main graph.
2) Calls `torch._inductor.standalone_compile` on these subgraphs and jams the returned callable into the FX graph in place of the `call_module` node.

The resulting graph looks something like this (search for `torch__inductor_standalone_compile_inner`).

Forward graph

```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(sin, primals_2)

        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
        getitem: "f32[10]" = inner[0];  inner = None

        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
        return (sin_1, primals_1, primals_2, sin, getitem)
```

Backward graph

```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1);  primals_1 = None

        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        cos: "f32[10]" = torch.ops.aten.cos.default(add);  add = None
        mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2);  mul_1 = sin = primals_2 = None

        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
        getitem: "f32[10]" = inner[0]
        getitem_1: "f32[10]" = inner[1];  inner = None

        # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1);  getitem_1 = cos_1 = None
        return (mul_4, getitem)
```

### Some issues raised in the HOP meeting

1) CSE will not differentiate between different meta custom nodes and will do the wrong thing.
2) SAC - the recomputed forward will be smaller than the forward. Will we then compile a smaller region?
3) What happens if you have an op in the middle which does not disturb the topology - is it still one subgraph?
4) What happens with nesting of `fx_traceback.annotate`? Are there any ordering requirements?
5) What are we going to use the annotations for?
   a) compile flex
   b) streams
   c) nn.Module info to organize MoE components for pipelining
   d) PP stages
   e) Rename graph nodes for more debugging
   f) No nested regional compile

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165188
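For reference, a rough sketch of the partition step from the Implementation section above, under the assumption that `fx_traceback.annotate` metadata ends up in `node.meta["custom"]`; the actual pass in this PR may differ in its details:

```
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


class AnnotatedNodesSupport(OperatorSupportBase):
    # A node belongs to a compiled region if it carries the annotation
    # (assumed to live under node.meta["custom"]).
    def is_node_supported(self, submodules, node):
        return "compile_with_inductor" in node.meta.get("custom", {})


def scoop_annotated_regions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Step 1: group annotated nodes into call_module subgraphs in the main graph.
    partitioner = CapabilityBasedPartitioner(
        gm,
        AnnotatedNodesSupport(),
        allows_single_node_partition=True,
    )
    gm = partitioner.fuse_partitions(partitioner.propose_partitions())

    # Step 2 (not shown): for each fused call_module submodule, run
    # torch._inductor.standalone_compile and splice the returned callable
    # into the graph in place of the call_module node.
    return gm
```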
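The file below appears to be `torch/fx/passes/__init__.py` (path assumed from context), now registering the new `regional_inductor` pass module alongside the existing FX passes: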
from . import (
    graph_drawer,
    graph_manipulation,
    net_min_base,
    operator_support,
    param_fetch,
    regional_inductor,
    reinplace,
    runtime_assert,
    shape_prop,
    split_module,
    split_utils,
    splitter_base,
    tools_common,
)