mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[compile] Regional inductor compilation with fx.annotate (#164776)
This PR introduces a way to compile a region of FX graph using `fx.traceback.annotate`. ### UX 1) In the user code, mark the region that you want to be compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we just rely on the string `compile_with_inductor` and ignore the integer. As the needs arise, we can update the logic. Example ``` def fn(x, y): sin = torch.sin(x) with fx_traceback.annotate({"compile_with_inductor": 0}): mul = sin * y add = mul + 1 return torch.sin(add) ``` 2) You have to instruct the compiler to use the annotations with `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that just setting annotation is enough. But for now to control the blast radius, we need to explicitly do this. One such example is ``` # Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor` def aot_eager_regional_inductor(): return aot_autograd( fw_compiler=compile_fx_annotated_nodes_with_inductor, bw_compiler=compile_fx_annotated_nodes_with_inductor, ) ``` 3) Fixable in short-term - You have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable, just need to make CI happy. ### Implementation 1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations, and then create subgraphs in the main graph. 2) Call `torch._inductor.standalone_compile` on these subgraphs, and jam the returned callable into the FX graph at the place of call_module Resulting graph looks something like this - search for `torch__inductor_standalone_compile_inner` Forward graph ``` class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"): # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x) sin: "f32[10]" = torch.ops.aten.sin.default(primals_1) # No stacktrace found for following nodes inner = torch__inductor_standalone_compile_inner(sin, primals_2) # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1 getitem: "f32[10]" = inner[0]; inner = None # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add) sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem) return (sin_1, primals_1, primals_2, sin, getitem) ``` Backward graph ``` class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"): # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x) cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1); primals_1 = None # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add) cos: "f32[10]" = torch.ops.aten.cos.default(add); add = None mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None # No stacktrace found for following nodes inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2); mul_1 = sin = primals_2 = None # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y getitem: "f32[10]" = inner[0] getitem_1: "f32[10]" = inner[1]; inner = None # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x) mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1); getitem_1 = cos_1 = None return (mul_4, getitem) ``` ### Some issue raised in the HOP meeting 1) CSE will not differentiate different meta custom nodes and do wrong thing. 2) SAC - The recomputed forward will be smaller than the forward. Will we compile a smaller region than? 3) What happens if you have a op in the middle which does not disturb the topology, is it still 1 subgraph? 4) What happens with the nesting of `fx_traceback.annotate`? Are there any ordering requirements? 5) What are we going to use the annotations for? a) compile flex b) streams c) nn.Module info to organize MoE components for pipelining d) PP stages e) Rename graph nodes for more debugging f) No nested regional compile Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776 Approved by: https://github.com/SherlockNoMad ghstack dependencies: #165188
This commit is contained in:
committed by
PyTorch MergeBot
parent
1191e51c44
commit
f3683453ae
@ -1019,6 +1019,8 @@ coverage_ignore_functions = [
|
||||
"loop_pass",
|
||||
"these_before_those_pass_constraint",
|
||||
"this_before_that_pass_constraint",
|
||||
# torch.fx.passes.regional_inductor
|
||||
"regional_inductor",
|
||||
# torch.fx.passes.reinplace
|
||||
"reinplace",
|
||||
# torch.fx.passes.split_module
|
||||
|
@ -1169,6 +1169,7 @@ The set of leaf modules can be customized by overriding
|
||||
.. py:module:: torch.fx.passes.operator_support
|
||||
.. py:module:: torch.fx.passes.param_fetch
|
||||
.. py:module:: torch.fx.passes.pass_manager
|
||||
.. py:module:: torch.fx.passes.regional_inductor
|
||||
.. py:module:: torch.fx.passes.reinplace
|
||||
.. py:module:: torch.fx.passes.runtime_assert
|
||||
.. py:module:: torch.fx.passes.shape_prop
|
||||
|
284
test/dynamo/test_regional_inductor.py
Normal file
284
test/dynamo/test_regional_inductor.py
Normal file
@ -0,0 +1,284 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch._inductor.test_case
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._inductor.test_case import run_tests
|
||||
from torch._inductor.utils import run_fw_bw_and_get_code
|
||||
from torch.fx.passes.regional_inductor import regional_inductor
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
||||
|
||||
# Open questions / follow-ups
|
||||
# 1) CSE behavior with meta custom nodes
|
||||
# Common subexpression elimination may not differentiate between distinct meta
|
||||
# custom nodes and could remove expressions, which might confuse users.
|
||||
#
|
||||
# 2) SAC: recompute vs. forward size
|
||||
# If the recomputed forward is smaller than the original forward, do we end up
|
||||
# compiling only the smaller region?
|
||||
#
|
||||
# 3) fx_traceback.annotate nesting
|
||||
# How does nesting behave? Are there any ordering requirements?
|
||||
#
|
||||
# 4) Planned uses for annotations
|
||||
# a) compile flex
|
||||
# b) streams
|
||||
# c) nn.Module info to organize MoE runtime
|
||||
# d) pipeline-parallel stages
|
||||
# e) rename graph nodes for easier debugging
|
||||
# f) disallow nested regional compile
|
||||
|
||||
|
||||
def aot_eager_regional_inductor():
|
||||
return aot_autograd(
|
||||
fw_compiler=regional_inductor,
|
||||
bw_compiler=regional_inductor,
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
|
||||
class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
def test_simple(self):
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
mul = sin * y
|
||||
add = mul + 1
|
||||
|
||||
return torch.sin(add)
|
||||
|
||||
opt_fn = torch.compile(
|
||||
fn, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
y = torch.randn(10, requires_grad=True)
|
||||
|
||||
# Check that inductor compilation is called twice
|
||||
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x, y))
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
def test_repeated_blocks(self):
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
mul = sin * y
|
||||
add = mul + 1
|
||||
|
||||
return torch.sin(add)
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
a = fn(x, y)
|
||||
return fn(a, y)
|
||||
|
||||
mod = Mod()
|
||||
|
||||
opt_mod = torch.compile(
|
||||
mod, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
y = torch.randn(10, requires_grad=True)
|
||||
|
||||
# Check that inductor compilation is called 4 times
|
||||
# there will be 2 partitions in the fwd and 2 in the bwd, totalling 4
|
||||
_, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
|
||||
self.assertEqual(len(codes), 4)
|
||||
|
||||
def test_invoke_subgraph(self):
|
||||
# Checks that get_attr nodes custom metadata is propagated
|
||||
@torch.compiler.nested_compile_region
|
||||
def gn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
def fn(x):
|
||||
x = x + 1
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
z = gn(x)
|
||||
return torch.sigmoid(z)
|
||||
|
||||
opt_fn = torch.compile(
|
||||
fn, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
|
||||
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
def test_invoke_subgraph_inner(self):
|
||||
# Checks that the inductor regions are searched recursively.
|
||||
@torch.compiler.nested_compile_region
|
||||
def gn(x):
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
return torch.sin(x)
|
||||
|
||||
def fn(x):
|
||||
x = x + 1
|
||||
x = gn(x)
|
||||
x = x + 1
|
||||
x = gn(x)
|
||||
return torch.sigmoid(x)
|
||||
|
||||
opt_fn = torch.compile(
|
||||
fn, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
|
||||
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
|
||||
# the invoke_subgraph is called twice - but the inside code is compiled
|
||||
# once - so in total 2 (1 fwd + 1 bwd)
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_flex_attention(self):
|
||||
def _squared(score, b, h, m, n):
|
||||
return score * score
|
||||
|
||||
def mask_mod(b, h, q, k):
|
||||
return q >= 0
|
||||
|
||||
a = 12
|
||||
b = 64
|
||||
block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)
|
||||
|
||||
def fn(x):
|
||||
x = torch.sin(x)
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
|
||||
return torch.cos(x)
|
||||
|
||||
x = torch.randn(
|
||||
1,
|
||||
1,
|
||||
a * b,
|
||||
b,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
opt_fn = torch.compile(
|
||||
fn,
|
||||
backend=aot_eager_regional_inductor(),
|
||||
fullgraph=True,
|
||||
)
|
||||
|
||||
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
|
||||
# flex in forward and flex_backward in backward
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_selective_ac_flex(self):
|
||||
class FlexAttentionModule(torch.nn.Module):
|
||||
def __init__(self, hidden_size, num_heads):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
|
||||
# In-projections (query, key, value)
|
||||
self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
# Out-projection
|
||||
self.out_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
# Project queries, keys, and values
|
||||
q = (
|
||||
self.q_proj(x)
|
||||
.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
k = (
|
||||
self.k_proj(x)
|
||||
.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
v = (
|
||||
self.v_proj(x)
|
||||
.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
# Apply flex attention
|
||||
with torch.fx.traceback.annotate({"compile_with_inductor": 0}):
|
||||
attn_output = flex_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
)
|
||||
|
||||
# Reshape output
|
||||
attn_output = (
|
||||
attn_output.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view(batch_size, seq_len, self.hidden_size)
|
||||
)
|
||||
|
||||
# Out projection
|
||||
output = self.out_proj(attn_output)
|
||||
|
||||
return output
|
||||
|
||||
from torch.utils.checkpoint import (
|
||||
checkpoint,
|
||||
create_selective_checkpoint_contexts,
|
||||
)
|
||||
|
||||
ops_to_save = [
|
||||
torch.ops.aten.mm.default,
|
||||
]
|
||||
context_fn = functools.partial(
|
||||
create_selective_checkpoint_contexts, ops_to_save
|
||||
)
|
||||
|
||||
# Define a model that uses FlexAttention with selective activation checkpointing
|
||||
class SacModule(torch.nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, context_fn):
|
||||
super().__init__()
|
||||
self.flex_attn = FlexAttentionModule(hidden_size, num_heads)
|
||||
self.context_fn = context_fn
|
||||
|
||||
def forward(self, x):
|
||||
def flex_attn_fn(x):
|
||||
return self.flex_attn(x)
|
||||
|
||||
output = checkpoint(
|
||||
flex_attn_fn,
|
||||
x,
|
||||
use_reentrant=False,
|
||||
context_fn=self.context_fn,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
|
||||
"cuda", dtype=torch.bfloat16
|
||||
)
|
||||
x = torch.ones(8, 1024, 512, device="cuda", dtype=torch.bfloat16)
|
||||
compiled_module = torch.compile(
|
||||
flex_module, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
|
||||
_, codes = run_fw_bw_and_get_code(lambda: compiled_module(x))
|
||||
# flex in forward and flex_backward in backward
|
||||
self.assertEqual(len(codes), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -340,15 +340,12 @@ class GraphModule(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[8]", primals_2: "f32[8]", primals_3: "f32[8]"):
|
||||
partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, primals_2, primals_3); partitioned_fw_subgraph_0_0 = None
|
||||
getitem_12: "f32[8]" = invoke_subgraph_4[3]
|
||||
getitem_11: "f32[8]" = invoke_subgraph_4[2]
|
||||
getitem_10: "f32[8]" = invoke_subgraph_4[1]
|
||||
getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
||||
|
||||
partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', primals_1, primals_2, primals_3); partitioned_fw_subgraph_0_1 = primals_1 = primals_2 = primals_3 = None
|
||||
getitem_15: "f32[8]" = invoke_subgraph_6[3]
|
||||
getitem_14: "f32[8]" = invoke_subgraph_6[2]
|
||||
@ -373,13 +370,10 @@ class GraphModule(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, getitem_12: "f32[8]", getitem_11: "f32[8]", getitem_10: "f32[8]", getitem_15: "f32[8]", getitem_14: "f32[8]", getitem_13: "f32[8]", tangents_1: "f32[8]"):
|
||||
partitioned_bw_subgraph_0_1 = self.partitioned_bw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_7 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_1, 'partitioned_bw_subgraph_0_0', getitem_13, getitem_14, getitem_15, tangents_1); partitioned_bw_subgraph_0_1 = getitem_13 = getitem_14 = getitem_15 = None
|
||||
getitem_2: "f32[8]" = invoke_subgraph_7[0]
|
||||
getitem_3: "f32[8]" = invoke_subgraph_7[1]; invoke_subgraph_7 = None
|
||||
|
||||
partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_10, getitem_11, getitem_12, tangents_1); partitioned_bw_subgraph_0_0 = getitem_10 = getitem_11 = getitem_12 = tangents_1 = None
|
||||
getitem_6: "f32[8]" = invoke_subgraph_5[0]
|
||||
getitem_7: "f32[8]" = invoke_subgraph_5[1]; invoke_subgraph_5 = None
|
||||
@ -657,14 +651,11 @@ class GraphModule(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[8]"):
|
||||
partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1); partitioned_fw_subgraph_0_0 = None
|
||||
getitem_7: "b8[8]" = invoke_subgraph_4[2]
|
||||
getitem_6: "f32[8]" = invoke_subgraph_4[1]
|
||||
getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
||||
|
||||
partitioned_fw_subgraph_1_0 = self.partitioned_fw_subgraph_1_0
|
||||
|
||||
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_1_0, 'partitioned_fw_subgraph_1_0', primals_1); partitioned_fw_subgraph_1_0 = primals_1 = None
|
||||
getitem_8: "f32[8]" = invoke_subgraph_6[1]
|
||||
getitem_1: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
|
||||
@ -798,14 +789,12 @@ class GraphModule(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[8]", primals_2: "f32[8]"):
|
||||
partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, primals_2); partitioned_fw_subgraph_0_0 = primals_1 = None
|
||||
getitem_9: "f32[8]" = invoke_subgraph_4[2]
|
||||
getitem_8: "f32[8]" = invoke_subgraph_4[1]
|
||||
getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
||||
|
||||
partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', getitem, primals_2); partitioned_fw_subgraph_0_1 = getitem = primals_2 = None
|
||||
getitem_11: "f32[8]" = invoke_subgraph_6[2]
|
||||
getitem_10: "f32[8]" = invoke_subgraph_6[1]
|
||||
@ -1517,7 +1506,6 @@ class GraphModule(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[8, 8]"):
|
||||
partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1); partitioned_fw_subgraph_0_0 = primals_1 = None
|
||||
getitem: "f32[8, 8]" = invoke_subgraph_2[0]
|
||||
getitem_1: "f32[8, 8]" = invoke_subgraph_2[1]; invoke_subgraph_2 = None
|
||||
@ -1539,7 +1527,6 @@ class GraphModule(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, tangents_1: "f32[8, 8]"):
|
||||
partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', tangents_1, tangents_1); partitioned_bw_subgraph_0_0 = tangents_1 = None
|
||||
getitem_2: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
|
||||
return (getitem_2,)
|
||||
@ -1678,7 +1665,6 @@ class GraphModule(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[8, 8]", primals_2: "f32[8, 8]"):
|
||||
partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, primals_2); partitioned_fw_subgraph_0_0 = primals_1 = primals_2 = None
|
||||
getitem_6: "f32[8, 8]" = invoke_subgraph_2[3]
|
||||
getitem_5: "f32[8, 8]" = invoke_subgraph_2[2]
|
||||
@ -1709,7 +1695,6 @@ class GraphModule(torch.nn.Module):
|
||||
mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
|
||||
|
||||
partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_4, getitem_5, getitem_6, mul); partitioned_bw_subgraph_0_0 = getitem_4 = getitem_5 = getitem_6 = mul = None
|
||||
getitem_1: "f32[8, 8]" = invoke_subgraph_3[0]
|
||||
getitem_2: "f32[8, 8]" = invoke_subgraph_3[1]; invoke_subgraph_3 = None
|
||||
@ -2256,14 +2241,12 @@ class GraphModule(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s77)", primals_2: "f32[s77, 16]"):
|
||||
partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_1
|
||||
|
||||
invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_1', primals_1, primals_2); partitioned_fw_subgraph_0_1 = primals_2 = None
|
||||
getitem_17: "Sym(s77)" = invoke_subgraph_8[2]
|
||||
getitem_16: "f32[s77, 16]" = invoke_subgraph_8[1]
|
||||
getitem: "f32[s77, 16]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None
|
||||
|
||||
partitioned_fw_subgraph_0_2 = self.partitioned_fw_subgraph_0_1
|
||||
|
||||
invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_2, 'partitioned_fw_subgraph_0_1', primals_1, getitem); partitioned_fw_subgraph_0_2 = getitem = None
|
||||
getitem_19: "Sym(s77)" = invoke_subgraph_10[2]
|
||||
getitem_18: "f32[s77, 16]" = invoke_subgraph_10[1]
|
||||
@ -2272,14 +2255,12 @@ class GraphModule(torch.nn.Module):
|
||||
sin: "f32[s77, 16]" = torch.ops.aten.sin.default(getitem_1)
|
||||
|
||||
partitioned_fw_subgraph_0_3 = self.partitioned_fw_subgraph_0_1
|
||||
|
||||
invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_3, 'partitioned_fw_subgraph_0_1', primals_1, sin); partitioned_fw_subgraph_0_3 = sin = None
|
||||
getitem_21: "Sym(s77)" = invoke_subgraph_12[2]
|
||||
getitem_20: "f32[s77, 16]" = invoke_subgraph_12[1]
|
||||
getitem_2: "f32[s77, 16]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None
|
||||
|
||||
partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_14 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, getitem_2); partitioned_fw_subgraph_0_0 = None
|
||||
getitem_23: "Sym(s77)" = invoke_subgraph_14[2]
|
||||
getitem_22: "f32[s77, 16]" = invoke_subgraph_14[1]
|
||||
@ -2311,26 +2292,22 @@ class GraphModule(torch.nn.Module):
|
||||
expand: "f32[s77, 16]" = torch.ops.aten.expand.default(tangents_1, [primals_1, 16]); tangents_1 = primals_1 = None
|
||||
|
||||
partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
|
||||
|
||||
invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_23, getitem_22, expand); partitioned_bw_subgraph_0_0 = getitem_23 = getitem_22 = None
|
||||
getitem_5: "f32[s77, 16]" = invoke_subgraph_15[1]; invoke_subgraph_15 = None
|
||||
|
||||
add_16: "f32[s77, 16]" = torch.ops.aten.add.Tensor(expand, getitem_5); expand = getitem_5 = None
|
||||
|
||||
partitioned_bw_subgraph_0_3 = self.partitioned_bw_subgraph_0_1
|
||||
|
||||
invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_3, 'partitioned_bw_subgraph_0_1', getitem_21, getitem_20, add_16); partitioned_bw_subgraph_0_3 = getitem_21 = getitem_20 = add_16 = None
|
||||
getitem_8: "f32[s77, 16]" = invoke_subgraph_13[1]; invoke_subgraph_13 = None
|
||||
|
||||
mul_10: "f32[s77, 16]" = torch.ops.aten.mul.Tensor(getitem_8, cos); getitem_8 = cos = None
|
||||
|
||||
partitioned_bw_subgraph_0_2 = self.partitioned_bw_subgraph_0_1
|
||||
|
||||
invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_2, 'partitioned_bw_subgraph_0_1', getitem_19, getitem_18, mul_10); partitioned_bw_subgraph_0_2 = getitem_19 = getitem_18 = mul_10 = None
|
||||
getitem_11: "f32[s77, 16]" = invoke_subgraph_11[1]; invoke_subgraph_11 = None
|
||||
|
||||
partitioned_bw_subgraph_0_1 = self.partitioned_bw_subgraph_0_1
|
||||
|
||||
invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_1, 'partitioned_bw_subgraph_0_1', getitem_17, getitem_16, getitem_11); partitioned_bw_subgraph_0_1 = getitem_17 = getitem_16 = getitem_11 = None
|
||||
getitem_14: "f32[s77, 16]" = invoke_subgraph_9[1]; invoke_subgraph_9 = None
|
||||
return (None, getitem_14)
|
||||
|
@ -854,6 +854,7 @@ def run_joint_graph_passes_on_hops(
|
||||
with joint_gm.graph.inserting_after(fw_node):
|
||||
new_fw_mod_attr_name = add_new_hop_gm(new_fw_hop_gm, f"fw{identifier}")
|
||||
new_fw_mod_attr = joint_gm.graph.get_attr(new_fw_mod_attr_name)
|
||||
new_fw_mod_attr.meta = copy.copy(fw_node.args[0].meta)
|
||||
|
||||
# new_hop_fw_gm output signature is (*fw_outs, *saved_tensors)
|
||||
with joint_gm.graph.inserting_after(new_fw_mod_attr):
|
||||
@ -906,6 +907,7 @@ def run_joint_graph_passes_on_hops(
|
||||
with joint_gm.graph.inserting_after(bw_node):
|
||||
new_bw_mod_attr_name = add_new_hop_gm(new_bw_hop_gm, bw_node.args[1])
|
||||
new_bw_mod_attr = joint_gm.graph.get_attr(new_bw_mod_attr_name)
|
||||
new_bw_mod_attr.meta = copy.copy(bw_node.args[0].meta)
|
||||
|
||||
with joint_gm.graph.inserting_after(new_bw_mod_attr):
|
||||
new_bw_node = joint_gm.graph.call_function(
|
||||
|
@ -4,6 +4,7 @@ from . import (
|
||||
net_min_base,
|
||||
operator_support,
|
||||
param_fetch,
|
||||
regional_inductor,
|
||||
reinplace,
|
||||
runtime_assert,
|
||||
shape_prop,
|
||||
|
133
torch/fx/passes/regional_inductor.py
Normal file
133
torch/fx/passes/regional_inductor.py
Normal file
@ -0,0 +1,133 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
import functools
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.fx._compatibility import compatibility
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["regional_inductor"]
|
||||
|
||||
|
||||
# standalone_inductor returns a callable class object - this does not sit well
|
||||
# with Fx graph node op call_function which expects a function. So this is just
|
||||
# a wrapper function to make Fx graph codegen happy.
|
||||
def _dummy_wrapper(fn):
|
||||
@functools.wraps(fn)
|
||||
def inner(*args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def _partition_by_supported_nodes(gm, supported_ops, prefix):
|
||||
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
|
||||
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
|
||||
|
||||
partitioner = CapabilityBasedPartitioner(
|
||||
gm, supported_ops, allows_single_node_partition=True
|
||||
)
|
||||
|
||||
candidate_partitions = partitioner.propose_partitions()
|
||||
partitioned_gm = fuse_by_partitions(
|
||||
partitioner.graph_module,
|
||||
[partition.nodes for partition in candidate_partitions],
|
||||
prefix=prefix,
|
||||
always_return_tuple=True,
|
||||
)
|
||||
|
||||
return partitioned_gm
|
||||
|
||||
|
||||
def _compile_submod(gm, prefix):
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_module" and node.target.startswith(prefix):
|
||||
fake_inputs = []
|
||||
for inp_node in node.all_input_nodes:
|
||||
if hasattr(inp_node, "meta") and "val" in inp_node.meta:
|
||||
fake_inputs.append(inp_node.meta["val"])
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Partition is bad because non fake tensor value is seen {inp_node}"
|
||||
)
|
||||
|
||||
submod = getattr(gm, node.target)
|
||||
|
||||
# _dummy_wrapper is to make call_function happy
|
||||
compiled_submod = _dummy_wrapper(
|
||||
torch._inductor.standalone_compile(
|
||||
submod, fake_inputs, dynamic_shapes="from_tracing_context"
|
||||
)
|
||||
)
|
||||
|
||||
with gm.graph.inserting_after(node):
|
||||
new_node = gm.graph.call_function(
|
||||
compiled_submod, args=node.args, kwargs=node.kwargs
|
||||
)
|
||||
new_node.meta = node.meta
|
||||
node.replace_all_uses_with(new_node)
|
||||
gm.graph.erase_node(node)
|
||||
del gm._modules[node.target]
|
||||
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
||||
def _needs_inductor_compile(node):
|
||||
return (
|
||||
node.op not in ("placeholder", "output")
|
||||
and hasattr(node, "meta")
|
||||
and node.meta.get("custom", None)
|
||||
and "compile_with_inductor" in node.meta["custom"]
|
||||
)
|
||||
|
||||
|
||||
def _compile_fx_annotated_nodes_with_inductor(gm):
|
||||
from torch.fx.passes.operator_support import OperatorSupport
|
||||
|
||||
found_marked_node = False
|
||||
for node in gm.graph.nodes:
|
||||
if _needs_inductor_compile(node):
|
||||
found_marked_node = True
|
||||
break
|
||||
|
||||
if not found_marked_node:
|
||||
logger.info("No inductor marked nodes found")
|
||||
return gm
|
||||
|
||||
class InductorMarkedNodes(OperatorSupport):
|
||||
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
|
||||
return _needs_inductor_compile(node)
|
||||
|
||||
marked_nodes = InductorMarkedNodes()
|
||||
gm = _partition_by_supported_nodes(gm, marked_nodes, "__marked_inductor_submod")
|
||||
gm = _compile_submod(gm, "__marked_inductor_submod")
|
||||
return gm
|
||||
|
||||
|
||||
def _recursive_compile_fx_annotated_nodes_with_inductor(gm):
|
||||
for node in gm.graph.find_nodes(op="get_attr"):
|
||||
if _needs_inductor_compile(node):
|
||||
# If the get_attr itself is marked for compile, the outer graph will
|
||||
# take care of it. If we dont do that, we end up with nested
|
||||
# regional inductor compiles that do not work well.
|
||||
continue
|
||||
submod = getattr(gm, node.target)
|
||||
if isinstance(submod, torch.fx.GraphModule):
|
||||
_recursive_compile_fx_annotated_nodes_with_inductor(submod)
|
||||
|
||||
return _compile_fx_annotated_nodes_with_inductor(gm)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def regional_inductor(gm, *example_args):
|
||||
"""
|
||||
Scoops out inductor marked regions and compiles them with inductor.
|
||||
"""
|
||||
# fuser utils create new nodes using create_proxy which retains the seq_nr
|
||||
# metadata and cause issues
|
||||
with torch.fx.traceback.preserve_node_meta(enable=False):
|
||||
return _recursive_compile_fx_annotated_nodes_with_inductor(gm)
|
Reference in New Issue
Block a user