[Cutlass] Enable fusion with FusedSchedulerNodes (#153588)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153588
Approved by: https://github.com/eellison
ghstack dependencies: #152815
This commit is contained in:
Michael Lazos
2025-05-17 02:09:52 -07:00
committed by PyTorch MergeBot
parent f604732e2e
commit 7ebea09986
3 changed files with 68 additions and 51 deletions

View File

@ -1536,14 +1536,14 @@ class TestCutlassBackend(TestCase):
@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
@unittest.skip("Needs fused scheduler node fusion support, (upcoming PR)")
def test_evt_multi_output(self, op):
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
z = op(acc.relu(), *extra_args)
y = z + 1
return acc, y
z0 = acc.relu()
z = op(z0, *extra_args)
y = z + z0
return z, y
M = 1024
N = 512
@ -1556,7 +1556,7 @@ class TestCutlassBackend(TestCase):
ref_result = model(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2
)
torch.testing.assert_close(result, ref_result)

View File

@ -64,16 +64,15 @@ class CUDACPPScheduling(BaseScheduling):
def can_fuse_vertical(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode):
if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode):
assert node1.node, "node1.node should not be None"
assert node2.node, "node2.node should not be None"
return self._can_fuse_epilogue_impl(
cast(CUDATemplateBuffer, node1.node),
[],
node2, # type: ignore[arg-type]
)
elif self.is_cuda_cpp_fused_template(node1) and isinstance(
node2, SchedulerNode
node2, BaseSchedulerNode
):
assert node1.node, "node1.node should not be None"
assert node2.node, "node2.node should not be None"
@ -213,39 +212,46 @@ class CUDACPPScheduling(BaseScheduling):
"""
why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
ir_node_to_fuse = node_to_fuse.node
# for typing
assert ir_node_to_fuse
scheduler_nodes_to_fuse = node_to_fuse.get_nodes()
assert isinstance(cuda_template_buffer, CUDATemplateBuffer)
if not isinstance(ir_node_to_fuse, ComputedBuffer):
# Checks on constituent nodes
for s_node in scheduler_nodes_to_fuse:
node = s_node.node
if not isinstance(node, ComputedBuffer):
why(f"{node} is not a ComputedBuffer")
return False
if not isinstance(ir_node_to_fuse.data, Pointwise):
elif not isinstance(node.data, Pointwise):
why(f"{node} is not a Pointwise op")
return False
# We can fuse a Pointwise op that depends on the last fused epilogue node
# if any. If there is no epilogue node yet, it needs to depend on the template
# node
node_name = ir_node_to_fuse.get_computed_buffer_name() # type: ignore[attr-defined]
if node_name is None:
elif not node.get_computed_buffer_name(): # type: ignore[attr-defined]
why(f"{node} does not have a computed buffer name")
return False
assert (
len(existing_epilogue_nodes)
or cuda_template_buffer.get_name() in ir_node_to_fuse.get_read_names()
), "First epilogue node must read from cuda template buffer"
name = node.get_computed_buffer_name() # type: ignore[attr-defined]
# dtype can differ, and strides can differ as long as they are broadcastable
if ir_node_to_fuse.get_size() != cuda_template_buffer.get_size():
if node.get_size() != cuda_template_buffer.get_size():
why(
f"{cuda_template_buffer.get_name()}'s size: {cuda_template_buffer.get_size()} \
differs from {node_name}'s size: {ir_node_to_fuse.get_size()}"
f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \
size: {cuda_template_buffer.get_size()}"
)
return False
elif node_to_fuse.has_aliasing_or_mutation():
why(f"{node_name} has aliasing or mutation")
assert len(
existing_epilogue_nodes
) or cuda_template_buffer.get_name() in OrderedSet(
[rd.name for rd in node_to_fuse.read_writes.reads]
), "First epilogue node must read from cuda template buffer"
if node_to_fuse.has_aliasing_or_mutation():
why(f"{node_to_fuse.get_name()} has aliasing or mutation")
return False
elif node_to_fuse.is_reduction():
why(f"{node_name} is a reduction which is not yet supported by EVT")
why(
f"{node_to_fuse.get_name()} is a reduction which is not yet supported by EVT"
)
return False
elif (
not config.cuda.cutlass_epilogue_fusion_enabled
@ -264,7 +270,7 @@ differs from {node_name}'s size: {ir_node_to_fuse.get_size()}"
CutlassEVTCodegen.ir_to_evt_python_code(
cuda_template_buffer.get_name(),
existing_epilogue_nodes + [node_to_fuse],
existing_epilogue_nodes + list(node_to_fuse.get_nodes()),
OrderedSet(),
)

View File

@ -11,7 +11,7 @@ import torch._inductor.virtualized as virtualized
from torch._inductor.ir import ComputedBuffer, Pointwise
from torch._inductor.ops_handler import DefaultHandler, WrapperHandler
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import IndentedBuffer, OrderedSet
from torch._inductor.utils import DelayReplaceLine, IndentedBuffer, OrderedSet
from torch._inductor.virtualized import OpsValue
from ...virtualized import V
@ -96,7 +96,14 @@ class _AssignmentFormatter(DefaultHandler):
return OpsValue(line)
else:
var = self.parent_handler._tmp_var()
self.parent_handler.body.writeline(f"{var} = {line}")
line = DelayReplaceLine(
var,
lambda: "D"
if var == self.parent_handler.last_stored_var_name
else var,
f"{var} = {line}",
)
self.parent_handler.body.writeline(line)
return OpsValue(var)
else:
raise NotImplementedError(name)
@ -138,6 +145,7 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
self.cur_node: Optional[ComputedBuffer] = None
self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
self.is_D_assigned = False
self.D_var_name = None
if accumulator_node_name not in removed_buffers:
# cannot return accumulator directly, so alias it
@ -162,6 +170,8 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
index_vars = CutlassEVTCodegen.get_index_vars(node)
node.get_store_function()(index_vars)
codegen.finalize()
return (
codegen.get_reads(),
codegen.get_writes(),
@ -178,6 +188,17 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
]
)
def finalize(self) -> None:
# Rename the last store to D
# no other code references this store
# to workaround https://github.com/NVIDIA/cutlass/issues/2288
# Note: the delayed line will automatically rewrite the last assignment to
# be to D
buffer_name = self.var_name_to_buffer_name[self.last_stored_var_name]
self.var_name_to_buffer_name.pop(self.last_stored_var_name)
self.var_name_to_buffer_name["D"] = buffer_name
self.store_name_to_value[buffer_name] = OpsValue("D")
@contextmanager
def set_cur_node(self, node: ComputedBuffer) -> Generator[None, Any, Any]:
prev_node = self.cur_node
@ -213,22 +234,12 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
if name not in self.removed_buffers:
if index:
self._check_indexing(name, index)
value_to_write = value
if not self.is_D_assigned and name != self.accumulator_node_name:
# EVT requires an output to be named D lol
# so rename the first store to D
# accumulator cannot be assigned to D
# see https://github.com/NVIDIA/cutlass/issues/2288
self.body.writeline(f"D = {value} # cutlass evt requirement")
value_to_write = OpsValue("D")
self.is_D_assigned = True
assert value_to_write.value != _ACCUMULATOR_ARG_NAME, (
assert value.value != _ACCUMULATOR_ARG_NAME, (
"Cannot store accumulator arg name"
)
self.var_name_to_buffer_name[value_to_write.value] = name
self.store_name_to_value[name] = value_to_write
self.var_name_to_buffer_name[value.value] = name
self.store_name_to_value[name] = value
self.last_stored_var_name = value.value
return None
def _get_cur_node(self) -> ComputedBuffer: