diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8b0571d2fed2..d21e67c1caad 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1019,6 +1019,8 @@ coverage_ignore_functions = [
     "loop_pass",
     "these_before_those_pass_constraint",
     "this_before_that_pass_constraint",
+    # torch.fx.passes.regional_inductor
+    "regional_inductor",
     # torch.fx.passes.reinplace
     "reinplace",
     # torch.fx.passes.split_module
diff --git a/docs/source/fx.md b/docs/source/fx.md
index 8baa9589d1ac..c9c235382893 100644
--- a/docs/source/fx.md
+++ b/docs/source/fx.md
@@ -1169,6 +1169,7 @@ The set of leaf modules can be customized by overriding
 .. py:module:: torch.fx.passes.operator_support
 .. py:module:: torch.fx.passes.param_fetch
 .. py:module:: torch.fx.passes.pass_manager
+.. py:module:: torch.fx.passes.regional_inductor
 .. py:module:: torch.fx.passes.reinplace
 .. py:module:: torch.fx.passes.runtime_assert
 .. py:module:: torch.fx.passes.shape_prop
diff --git a/test/dynamo/test_regional_inductor.py b/test/dynamo/test_regional_inductor.py
new file mode 100644
index 000000000000..fc31e25dce3f
--- /dev/null
+++ b/test/dynamo/test_regional_inductor.py
@@ -0,0 +1,284 @@
+# Owner(s): ["module: dynamo"]
+
+import functools
+
+import torch
+import torch._inductor.test_case
+import torch.fx.traceback as fx_traceback
+import torch.utils.checkpoint
+from torch._dynamo.backends.common import aot_autograd
+from torch._inductor.test_case import run_tests
+from torch._inductor.utils import run_fw_bw_and_get_code
+from torch.fx.passes.regional_inductor import regional_inductor
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from torch.testing._internal.common_utils import skipIfTorchDynamo
+from torch.testing._internal.triton_utils import requires_cuda_and_triton
+
+
+# Open questions / follow-ups
+# 1) CSE behavior with meta custom nodes
+# Common subexpression elimination may not differentiate between distinct meta
+# custom nodes and could remove expressions, which might confuse users.
+#
+# 2) SAC: recompute vs. forward size
+# If the recomputed forward is smaller than the original forward, do we end up
+# compiling only the smaller region?
+#
+# 3) fx_traceback.annotate nesting
+# How does nesting behave? Are there any ordering requirements?
+#
+# 4) Planned uses for annotations
+# a) compile flex
+# b) streams
+# c) nn.Module info to organize MoE runtime
+# d) pipeline-parallel stages
+# e) rename graph nodes for easier debugging
+# f) disallow nested regional compile
+
+
+def aot_eager_regional_inductor():
+    return aot_autograd(
+        fw_compiler=regional_inductor,
+        bw_compiler=regional_inductor,
+    )
+
+
+@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
+class RegionalInductorTests(torch._inductor.test_case.TestCase):
+    def test_simple(self):
+        def fn(x, y):
+            sin = torch.sin(x)
+
+            with fx_traceback.annotate({"compile_with_inductor": 0}):
+                mul = sin * y
+                add = mul + 1
+
+            return torch.sin(add)
+
+        opt_fn = torch.compile(
+            fn, backend=aot_eager_regional_inductor(), fullgraph=True
+        )
+        x = torch.randn(10, requires_grad=True)
+        y = torch.randn(10, requires_grad=True)
+
+        # Check that inductor compilation is called twice
+        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x, y))
+        self.assertEqual(len(codes), 2)
+
+    def test_repeated_blocks(self):
+        def fn(x, y):
+            sin = torch.sin(x)
+
+            with fx_traceback.annotate({"compile_with_inductor": 0}):
+                mul = sin * y
+                add = mul + 1
+
+            return torch.sin(add)
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                a = fn(x, y)
+                return fn(a, y)
+
+        mod = Mod()
+
+        opt_mod = torch.compile(
+            mod, backend=aot_eager_regional_inductor(), fullgraph=True
+        )
+        x = torch.randn(10, requires_grad=True)
+        y = torch.randn(10, requires_grad=True)
+
+        # Check that inductor compilation is called 4 times:
+        # there will be 2 partitions in the fwd and 2 in the bwd, totalling 4
+        _, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
+        self.assertEqual(len(codes), 4)
+
+    def test_invoke_subgraph(self):
+        # Checks that get_attr nodes' custom metadata is propagated
+        @torch.compiler.nested_compile_region
+        def gn(x):
+            return torch.sin(x)
+
+        def fn(x):
+            x = x + 1
+            with fx_traceback.annotate({"compile_with_inductor": 0}):
+                z = gn(x)
+            return torch.sigmoid(z)
+
+        opt_fn = torch.compile(
+            fn, backend=aot_eager_regional_inductor(), fullgraph=True
+        )
+        x = torch.randn(10, requires_grad=True)
+
+        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
+        self.assertEqual(len(codes), 2)
+
+    def test_invoke_subgraph_inner(self):
+        # Checks that the inductor regions are searched recursively.
+        @torch.compiler.nested_compile_region
+        def gn(x):
+            with fx_traceback.annotate({"compile_with_inductor": 0}):
+                return torch.sin(x)
+
+        def fn(x):
+            x = x + 1
+            x = gn(x)
+            x = x + 1
+            x = gn(x)
+            return torch.sigmoid(x)
+
+        opt_fn = torch.compile(
+            fn, backend=aot_eager_regional_inductor(), fullgraph=True
+        )
+        x = torch.randn(10, requires_grad=True)
+
+        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
+        # the invoke_subgraph is called twice - but the inside code is compiled
+        # once - so in total 2 (1 fwd + 1 bwd)
+        self.assertEqual(len(codes), 2)
+
+    @requires_cuda_and_triton
+    def test_flex_attention(self):
+        def _squared(score, b, h, m, n):
+            return score * score
+
+        def mask_mod(b, h, q, k):
+            return q >= 0
+
+        a = 12
+        b = 64
+        block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)
+
+        def fn(x):
+            x = torch.sin(x)
+            with fx_traceback.annotate({"compile_with_inductor": 0}):
+                x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
+            return torch.cos(x)
+
+        x = torch.randn(
+            1,
+            1,
+            a * b,
+            b,
+            dtype=torch.bfloat16,
+            device="cuda",
+            requires_grad=True,
+        )
+
+        opt_fn = torch.compile(
+            fn,
+            backend=aot_eager_regional_inductor(),
+            fullgraph=True,
+        )
+
+        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
+        # flex in forward and flex_backward in backward
+        self.assertEqual(len(codes), 2)
+
+    @requires_cuda_and_triton
+    def test_selective_ac_flex(self):
+        class FlexAttentionModule(torch.nn.Module):
+            def __init__(self, hidden_size, num_heads):
+                super().__init__()
+                self.hidden_size = hidden_size
+                self.num_heads = num_heads
+                self.head_dim = hidden_size // num_heads
+
+                # In-projections (query, key, value)
+                self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
+                self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
+                self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
+
+                # Out-projection
+                self.out_proj = torch.nn.Linear(hidden_size, hidden_size)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.size()
+
+                # Project queries, keys, and values
+                q = (
+                    self.q_proj(x)
+                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
+                    .transpose(1, 2)
+                )
+                k = (
+                    self.k_proj(x)
+                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
+                    .transpose(1, 2)
+                )
+                v = (
+                    self.v_proj(x)
+                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
+                    .transpose(1, 2)
+                )
+
+                # Apply flex attention
+                with torch.fx.traceback.annotate({"compile_with_inductor": 0}):
+                    attn_output = flex_attention(
+                        q,
+                        k,
+                        v,
+                    )
+
+                # Reshape output
+                attn_output = (
+                    attn_output.transpose(1, 2)
+                    .contiguous()
+                    .view(batch_size, seq_len, self.hidden_size)
+                )
+
+                # Out projection
+                output = self.out_proj(attn_output)
+
+                return output
+
+        from torch.utils.checkpoint import (
+            checkpoint,
+            create_selective_checkpoint_contexts,
+        )
+
+        ops_to_save = [
+            torch.ops.aten.mm.default,
+        ]
+        context_fn = functools.partial(
+            create_selective_checkpoint_contexts, ops_to_save
+        )
+
+        # Define a model that uses FlexAttention with selective activation checkpointing
+        class SacModule(torch.nn.Module):
+            def __init__(self, hidden_size, num_heads, context_fn):
+                super().__init__()
+                self.flex_attn = FlexAttentionModule(hidden_size, num_heads)
+                self.context_fn = context_fn
+
+            def forward(self, x):
+                def flex_attn_fn(x):
+                    return self.flex_attn(x)
+
+                output = checkpoint(
+                    flex_attn_fn,
+                    x,
+                    use_reentrant=False,
+                    context_fn=self.context_fn,
+                )
+
+                return output
+
+        flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
+            "cuda", dtype=torch.bfloat16
+        )
+        x = torch.ones(8, 1024, 512, device="cuda", dtype=torch.bfloat16)
+        compiled_module = torch.compile(
+            flex_module, backend=aot_eager_regional_inductor(), fullgraph=True
+        )
+
+        _, codes = run_fw_bw_and_get_code(lambda: compiled_module(x))
+        # flex in forward and flex_backward in backward
+        self.assertEqual(len(codes), 2)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py
index 0922cb64ef88..ffbefe5cd9b4 100644
--- a/test/higher_order_ops/test_invoke_subgraph.py
+++ b/test/higher_order_ops/test_invoke_subgraph.py
@@ -340,15 +340,12 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[8]", primals_2: "f32[8]", primals_3: "f32[8]"):
         partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
-
         invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, primals_2, primals_3); partitioned_fw_subgraph_0_0 = None
         getitem_12: "f32[8]" = invoke_subgraph_4[3]
         getitem_11: "f32[8]" = invoke_subgraph_4[2]
         getitem_10: "f32[8]" = invoke_subgraph_4[1]
         getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
-
         partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0
-
         invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', primals_1, primals_2, primals_3); partitioned_fw_subgraph_0_1 = primals_1 = primals_2 = primals_3 = None
         getitem_15: "f32[8]" = invoke_subgraph_6[3]
         getitem_14: "f32[8]" = invoke_subgraph_6[2]
@@ -373,13 +370,10 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, getitem_12: "f32[8]", getitem_11: "f32[8]", getitem_10: "f32[8]", getitem_15: "f32[8]", getitem_14: "f32[8]", getitem_13: "f32[8]", tangents_1: "f32[8]"):
         partitioned_bw_subgraph_0_1 = self.partitioned_bw_subgraph_0_0
-
         invoke_subgraph_7 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_1, 'partitioned_bw_subgraph_0_0', getitem_13, getitem_14, getitem_15, tangents_1); partitioned_bw_subgraph_0_1 = getitem_13 = getitem_14 = getitem_15 = None
         getitem_2: "f32[8]" = invoke_subgraph_7[0]
         getitem_3: "f32[8]" = invoke_subgraph_7[1]; invoke_subgraph_7 = None
-
         partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
-
         invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_10, getitem_11, getitem_12, tangents_1); partitioned_bw_subgraph_0_0 = getitem_10 = getitem_11 = getitem_12 = tangents_1 = None
         getitem_6: "f32[8]" = invoke_subgraph_5[0]
         getitem_7: "f32[8]" = invoke_subgraph_5[1]; invoke_subgraph_5 = None
@@ -657,14 +651,11 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[8]"):
         partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
-
         invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1); partitioned_fw_subgraph_0_0 = None
         getitem_7: "b8[8]" = invoke_subgraph_4[2]
         getitem_6: "f32[8]" = invoke_subgraph_4[1]
         getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
-
         partitioned_fw_subgraph_1_0 = self.partitioned_fw_subgraph_1_0
-
         invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_1_0, 'partitioned_fw_subgraph_1_0', primals_1); partitioned_fw_subgraph_1_0 = primals_1 = None
         getitem_8: "f32[8]" = invoke_subgraph_6[1]
         getitem_1: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
@@ -798,14 +789,12 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[8]", primals_2: "f32[8]"):
         partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
-
         invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, primals_2); partitioned_fw_subgraph_0_0 = primals_1 = None
         getitem_9: "f32[8]" = invoke_subgraph_4[2]
         getitem_8: "f32[8]" = invoke_subgraph_4[1]
         getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None

         partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0
-
         invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', getitem, primals_2); partitioned_fw_subgraph_0_1 = getitem = primals_2 = None
         getitem_11: "f32[8]" = invoke_subgraph_6[2]
         getitem_10: "f32[8]" = invoke_subgraph_6[1]
@@ -1517,7 +1506,6 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[8, 8]"):
         partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
-
         invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1); partitioned_fw_subgraph_0_0 = primals_1 = None
         getitem: "f32[8, 8]" = invoke_subgraph_2[0]
         getitem_1: "f32[8, 8]" = invoke_subgraph_2[1]; invoke_subgraph_2 = None
@@ -1539,7 +1527,6 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, tangents_1: "f32[8, 8]"):
         partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
-
         invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', tangents_1, tangents_1); partitioned_bw_subgraph_0_0 = tangents_1 = None
         getitem_2: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
         return (getitem_2,)
@@ -1678,7 +1665,6 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[8, 8]", primals_2: "f32[8, 8]"):
         partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
-
         invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, primals_2); partitioned_fw_subgraph_0_0 = primals_1 = primals_2 = None
         getitem_6: "f32[8, 8]" = invoke_subgraph_2[3]
         getitem_5: "f32[8, 8]" = invoke_subgraph_2[2]
@@ -1709,7 +1695,6 @@ class GraphModule(torch.nn.Module):
         mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None

         partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
-
         invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_4, getitem_5, getitem_6, mul); partitioned_bw_subgraph_0_0 = getitem_4 = getitem_5 = getitem_6 = mul = None
         getitem_1: "f32[8, 8]" = invoke_subgraph_3[0]
         getitem_2: "f32[8, 8]" = invoke_subgraph_3[1]; invoke_subgraph_3 = None
@@ -2256,14 +2241,12 @@ class GraphModule(torch.nn.Module):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "Sym(s77)", primals_2: "f32[s77, 16]"):
         partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_1
-
         invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_1', primals_1, primals_2); partitioned_fw_subgraph_0_1 = primals_2 = None
         getitem_17: "Sym(s77)" = invoke_subgraph_8[2]
         getitem_16: "f32[s77, 16]" = invoke_subgraph_8[1]
         getitem: "f32[s77, 16]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None

         partitioned_fw_subgraph_0_2 = self.partitioned_fw_subgraph_0_1
-
         invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_2, 'partitioned_fw_subgraph_0_1', primals_1, getitem); partitioned_fw_subgraph_0_2 = getitem = None
         getitem_19: "Sym(s77)" = invoke_subgraph_10[2]
         getitem_18: "f32[s77, 16]" = invoke_subgraph_10[1]
@@ -2272,14 +2255,12 @@ class GraphModule(torch.nn.Module):
         sin: "f32[s77, 16]" = torch.ops.aten.sin.default(getitem_1)

         partitioned_fw_subgraph_0_3 = self.partitioned_fw_subgraph_0_1
-
         invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_3, 'partitioned_fw_subgraph_0_1', primals_1, sin); partitioned_fw_subgraph_0_3 = sin = None
         getitem_21: "Sym(s77)" = invoke_subgraph_12[2]
         getitem_20: "f32[s77, 16]" = invoke_subgraph_12[1]
         getitem_2: "f32[s77, 16]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None

         partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
-
         invoke_subgraph_14 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, getitem_2); partitioned_fw_subgraph_0_0 = None
         getitem_23: "Sym(s77)" = invoke_subgraph_14[2]
         getitem_22: "f32[s77, 16]" = invoke_subgraph_14[1]
@@ -2311,26 +2292,22 @@ class GraphModule(torch.nn.Module):
         expand: "f32[s77, 16]" = torch.ops.aten.expand.default(tangents_1, [primals_1, 16]); tangents_1 = primals_1 = None

         partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
-
         invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_23, getitem_22, expand); partitioned_bw_subgraph_0_0 = getitem_23 = getitem_22 = None
         getitem_5: "f32[s77, 16]" = invoke_subgraph_15[1]; invoke_subgraph_15 = None

         add_16: "f32[s77, 16]" = torch.ops.aten.add.Tensor(expand, getitem_5); expand = getitem_5 = None

         partitioned_bw_subgraph_0_3 = self.partitioned_bw_subgraph_0_1
-
         invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_3, 'partitioned_bw_subgraph_0_1', getitem_21, getitem_20, add_16); partitioned_bw_subgraph_0_3 = getitem_21 = getitem_20 = add_16 = None
         getitem_8: "f32[s77, 16]" = invoke_subgraph_13[1]; invoke_subgraph_13 = None

         mul_10: "f32[s77, 16]" = torch.ops.aten.mul.Tensor(getitem_8, cos); getitem_8 = cos = None

         partitioned_bw_subgraph_0_2 = self.partitioned_bw_subgraph_0_1
-
         invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_2, 'partitioned_bw_subgraph_0_1', getitem_19, getitem_18, mul_10); partitioned_bw_subgraph_0_2 = getitem_19 = getitem_18 = mul_10 = None
         getitem_11: "f32[s77, 16]" = invoke_subgraph_11[1]; invoke_subgraph_11 = None

         partitioned_bw_subgraph_0_1 = self.partitioned_bw_subgraph_0_1
-
         invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_1, 'partitioned_bw_subgraph_0_1', getitem_17, getitem_16, getitem_11); partitioned_bw_subgraph_0_1 = getitem_17 = getitem_16 = getitem_11 = None
         getitem_14: "f32[s77, 16]" = invoke_subgraph_9[1]; invoke_subgraph_9 = None

         return (None, getitem_14)
diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py
index 2e6d8b97eebc..aac28cbabe61 100644
--- a/torch/_functorch/_aot_autograd/graph_compile.py
+++ b/torch/_functorch/_aot_autograd/graph_compile.py
@@ -854,6 +854,7 @@ def run_joint_graph_passes_on_hops(
         with joint_gm.graph.inserting_after(fw_node):
             new_fw_mod_attr_name = add_new_hop_gm(new_fw_hop_gm, f"fw{identifier}")
             new_fw_mod_attr = joint_gm.graph.get_attr(new_fw_mod_attr_name)
+            new_fw_mod_attr.meta = copy.copy(fw_node.args[0].meta)

             # new_hop_fw_gm output signature is (*fw_outs, *saved_tensors)
             with joint_gm.graph.inserting_after(new_fw_mod_attr):
@@ -906,6 +907,7 @@ def run_joint_graph_passes_on_hops(
         with joint_gm.graph.inserting_after(bw_node):
             new_bw_mod_attr_name = add_new_hop_gm(new_bw_hop_gm, bw_node.args[1])
             new_bw_mod_attr = joint_gm.graph.get_attr(new_bw_mod_attr_name)
+            new_bw_mod_attr.meta = copy.copy(bw_node.args[0].meta)

             with joint_gm.graph.inserting_after(new_bw_mod_attr):
                 new_bw_node = joint_gm.graph.call_function(
diff --git a/torch/fx/passes/__init__.py b/torch/fx/passes/__init__.py
index 433d8818e259..3bcb6e1d75a1 100644
--- a/torch/fx/passes/__init__.py
+++ b/torch/fx/passes/__init__.py
@@ -4,6 +4,7 @@ from . import (
     net_min_base,
     operator_support,
     param_fetch,
+    regional_inductor,
     reinplace,
     runtime_assert,
     shape_prop,
diff --git a/torch/fx/passes/regional_inductor.py b/torch/fx/passes/regional_inductor.py
new file mode 100644
index 000000000000..dfd1643513e1
--- /dev/null
+++ b/torch/fx/passes/regional_inductor.py
@@ -0,0 +1,133 @@
+# mypy: allow-untyped-defs
+
+import functools
+import logging
+
+import torch
+from torch.fx._compatibility import compatibility
+
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["regional_inductor"]
+
+
+# torch._inductor.standalone_compile returns a callable class object, which does
+# not sit well with the FX graph node op call_function that expects a function.
+# This is just a wrapper function to make FX graph codegen happy.
+def _dummy_wrapper(fn):
+    @functools.wraps(fn)
+    def inner(*args, **kwargs):
+        return fn(*args, **kwargs)
+
+    return inner
+
+
+def _partition_by_supported_nodes(gm, supported_ops, prefix):
+    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+    from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
+
+    partitioner = CapabilityBasedPartitioner(
+        gm, supported_ops, allows_single_node_partition=True
+    )
+
+    candidate_partitions = partitioner.propose_partitions()
+    partitioned_gm = fuse_by_partitions(
+        partitioner.graph_module,
+        [partition.nodes for partition in candidate_partitions],
+        prefix=prefix,
+        always_return_tuple=True,
+    )
+
+    return partitioned_gm
+
+
+def _compile_submod(gm, prefix):
+    for node in gm.graph.nodes:
+        if node.op == "call_module" and node.target.startswith(prefix):
+            fake_inputs = []
+            for inp_node in node.all_input_nodes:
+                if hasattr(inp_node, "meta") and "val" in inp_node.meta:
+                    fake_inputs.append(inp_node.meta["val"])
+                else:
+                    raise RuntimeError(
+                        f"Bad partition: no fake tensor value found for input node {inp_node}"
+                    )
+
+            submod = getattr(gm, node.target)
+
+            # _dummy_wrapper is to make call_function happy
+            compiled_submod = _dummy_wrapper(
+                torch._inductor.standalone_compile(
+                    submod, fake_inputs, dynamic_shapes="from_tracing_context"
+                )
+            )
+
+            with gm.graph.inserting_after(node):
+                new_node = gm.graph.call_function(
+                    compiled_submod, args=node.args, kwargs=node.kwargs
+                )
+                new_node.meta = node.meta
+                node.replace_all_uses_with(new_node)
+                gm.graph.erase_node(node)
+                del gm._modules[node.target]
+
+    gm.recompile()
+    return gm
+
+
+def _needs_inductor_compile(node):
+    return (
+        node.op not in ("placeholder", "output")
+        and hasattr(node, "meta")
+        and node.meta.get("custom", None)
+        and "compile_with_inductor" in node.meta["custom"]
+    )
+
+
+def _compile_fx_annotated_nodes_with_inductor(gm):
+    from torch.fx.passes.operator_support import OperatorSupport
+
+    found_marked_node = False
+    for node in gm.graph.nodes:
+        if _needs_inductor_compile(node):
+            found_marked_node = True
+            break
+
+    if not found_marked_node:
+        logger.info("No inductor marked nodes found")
+        return gm
+
+    class InductorMarkedNodes(OperatorSupport):
+        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+            return _needs_inductor_compile(node)
+
+    marked_nodes = InductorMarkedNodes()
+    gm = _partition_by_supported_nodes(gm, marked_nodes, "__marked_inductor_submod")
+    gm = _compile_submod(gm, "__marked_inductor_submod")
+    return gm
+
+
+def _recursive_compile_fx_annotated_nodes_with_inductor(gm):
+    for node in gm.graph.find_nodes(op="get_attr"):
+        if _needs_inductor_compile(node):
+            # If the get_attr itself is marked for compile, the outer graph will
+            # take care of it. If we don't skip it here, we end up with nested
+            # regional inductor compiles that do not work well.
+            continue
+        submod = getattr(gm, node.target)
+        if isinstance(submod, torch.fx.GraphModule):
+            _recursive_compile_fx_annotated_nodes_with_inductor(submod)
+
+    return _compile_fx_annotated_nodes_with_inductor(gm)
+
+
+@compatibility(is_backward_compatible=False)
+def regional_inductor(gm, *example_args):
+    """
+    Scoops out the regions marked for Inductor compilation and compiles them with Inductor.
+    """
+    # fuser utils create new nodes using create_proxy, which retains the seq_nr
+    # metadata and causes issues
+    with torch.fx.traceback.preserve_node_meta(enable=False):
+        return _recursive_compile_fx_annotated_nodes_with_inductor(gm)
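
Usage sketch: a minimal example of how the new pass is wired into torch.compile, distilled from test_simple and the aot_eager_regional_inductor helper in the test file above. The names backend, fn, and opt_fn are illustrative only and are not part of the patch.

    import torch
    import torch.fx.traceback as fx_traceback
    from torch._dynamo.backends.common import aot_autograd
    from torch.fx.passes.regional_inductor import regional_inductor

    # Run regional_inductor as both the forward and backward compiler so that only
    # the annotated region is compiled with Inductor; the rest of the graph keeps
    # the aot_eager behavior.
    backend = aot_autograd(fw_compiler=regional_inductor, bw_compiler=regional_inductor)

    def fn(x, y):
        sin = torch.sin(x)
        # Nodes traced under this context carry the "compile_with_inductor" custom
        # metadata and are scooped out into an Inductor-compiled submodule.
        with fx_traceback.annotate({"compile_with_inductor": 0}):
            mul = sin * y
            add = mul + 1
        return torch.sin(add)

    opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
    out = opt_fn(torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True))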