mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Cutlass] Enable fusion with FusedSchedulerNodes (#153588)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153588 Approved by: https://github.com/eellison ghstack dependencies: #152815
This commit is contained in:
committed by
PyTorch MergeBot
parent
f604732e2e
commit
7ebea09986
@ -1536,14 +1536,14 @@ class TestCutlassBackend(TestCase):
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@use_evt_config
|
||||
@evt_all_ops
|
||||
@unittest.skip("Needs fused scheduler node fusion support, (upcoming PR)")
|
||||
def test_evt_multi_output(self, op):
|
||||
class TestModel(torch.nn.Module):
|
||||
def forward(self, a, b, extra_args):
|
||||
acc = a @ b
|
||||
z = op(acc.relu(), *extra_args)
|
||||
y = z + 1
|
||||
return acc, y
|
||||
z0 = acc.relu()
|
||||
z = op(z0, *extra_args)
|
||||
y = z + z0
|
||||
return z, y
|
||||
|
||||
M = 1024
|
||||
N = 512
|
||||
@ -1556,7 +1556,7 @@ class TestCutlassBackend(TestCase):
|
||||
ref_result = model(a, b, extra_args)
|
||||
|
||||
self.assertEqual(
|
||||
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
|
||||
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2
|
||||
)
|
||||
torch.testing.assert_close(result, ref_result)
|
||||
|
||||
|
@ -64,16 +64,15 @@ class CUDACPPScheduling(BaseScheduling):
|
||||
def can_fuse_vertical(
|
||||
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
|
||||
) -> bool:
|
||||
if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode):
|
||||
if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode):
|
||||
assert node1.node, "node1.node should not be None"
|
||||
assert node2.node, "node2.node should not be None"
|
||||
return self._can_fuse_epilogue_impl(
|
||||
cast(CUDATemplateBuffer, node1.node),
|
||||
[],
|
||||
node2, # type: ignore[arg-type]
|
||||
)
|
||||
elif self.is_cuda_cpp_fused_template(node1) and isinstance(
|
||||
node2, SchedulerNode
|
||||
node2, BaseSchedulerNode
|
||||
):
|
||||
assert node1.node, "node1.node should not be None"
|
||||
assert node2.node, "node2.node should not be None"
|
||||
@ -213,39 +212,46 @@ class CUDACPPScheduling(BaseScheduling):
|
||||
"""
|
||||
why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
|
||||
|
||||
ir_node_to_fuse = node_to_fuse.node
|
||||
# for typing
|
||||
assert ir_node_to_fuse
|
||||
scheduler_nodes_to_fuse = node_to_fuse.get_nodes()
|
||||
|
||||
assert isinstance(cuda_template_buffer, CUDATemplateBuffer)
|
||||
if not isinstance(ir_node_to_fuse, ComputedBuffer):
|
||||
return False
|
||||
if not isinstance(ir_node_to_fuse.data, Pointwise):
|
||||
return False
|
||||
# We can fuse a Pointwise op that depends on the last fused epilogue node
|
||||
# if any. If there is no epilogue node yet, it needs to depend on the template
|
||||
# node
|
||||
node_name = ir_node_to_fuse.get_computed_buffer_name() # type: ignore[attr-defined]
|
||||
if node_name is None:
|
||||
return False
|
||||
|
||||
assert (
|
||||
len(existing_epilogue_nodes)
|
||||
or cuda_template_buffer.get_name() in ir_node_to_fuse.get_read_names()
|
||||
# Checks on constituent nodes
|
||||
for s_node in scheduler_nodes_to_fuse:
|
||||
node = s_node.node
|
||||
|
||||
if not isinstance(node, ComputedBuffer):
|
||||
why(f"{node} is not a ComputedBuffer")
|
||||
return False
|
||||
elif not isinstance(node.data, Pointwise):
|
||||
why(f"{node} is not a Pointwise op")
|
||||
return False
|
||||
elif not node.get_computed_buffer_name(): # type: ignore[attr-defined]
|
||||
why(f"{node} does not have a computed buffer name")
|
||||
return False
|
||||
|
||||
name = node.get_computed_buffer_name() # type: ignore[attr-defined]
|
||||
# dtype can differ, and strides can differ as long as they are broadcastable
|
||||
if node.get_size() != cuda_template_buffer.get_size():
|
||||
why(
|
||||
f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \
|
||||
size: {cuda_template_buffer.get_size()}"
|
||||
)
|
||||
return False
|
||||
|
||||
assert len(
|
||||
existing_epilogue_nodes
|
||||
) or cuda_template_buffer.get_name() in OrderedSet(
|
||||
[rd.name for rd in node_to_fuse.read_writes.reads]
|
||||
), "First epilogue node must read from cuda template buffer"
|
||||
|
||||
# dtype can differ, and strides can differ as long as they are broadcastable
|
||||
if ir_node_to_fuse.get_size() != cuda_template_buffer.get_size():
|
||||
why(
|
||||
f"{cuda_template_buffer.get_name()}'s size: {cuda_template_buffer.get_size()} \
|
||||
differs from {node_name}'s size: {ir_node_to_fuse.get_size()}"
|
||||
)
|
||||
return False
|
||||
elif node_to_fuse.has_aliasing_or_mutation():
|
||||
why(f"{node_name} has aliasing or mutation")
|
||||
if node_to_fuse.has_aliasing_or_mutation():
|
||||
why(f"{node_to_fuse.get_name()} has aliasing or mutation")
|
||||
return False
|
||||
elif node_to_fuse.is_reduction():
|
||||
why(f"{node_name} is a reduction which is not yet supported by EVT")
|
||||
why(
|
||||
f"{node_to_fuse.get_name()} is a reduction which is not yet supported by EVT"
|
||||
)
|
||||
return False
|
||||
elif (
|
||||
not config.cuda.cutlass_epilogue_fusion_enabled
|
||||
@ -264,7 +270,7 @@ differs from {node_name}'s size: {ir_node_to_fuse.get_size()}"
|
||||
|
||||
CutlassEVTCodegen.ir_to_evt_python_code(
|
||||
cuda_template_buffer.get_name(),
|
||||
existing_epilogue_nodes + [node_to_fuse],
|
||||
existing_epilogue_nodes + list(node_to_fuse.get_nodes()),
|
||||
OrderedSet(),
|
||||
)
|
||||
|
||||
|
@ -11,7 +11,7 @@ import torch._inductor.virtualized as virtualized
|
||||
from torch._inductor.ir import ComputedBuffer, Pointwise
|
||||
from torch._inductor.ops_handler import DefaultHandler, WrapperHandler
|
||||
from torch._inductor.scheduler import BaseSchedulerNode
|
||||
from torch._inductor.utils import IndentedBuffer, OrderedSet
|
||||
from torch._inductor.utils import DelayReplaceLine, IndentedBuffer, OrderedSet
|
||||
from torch._inductor.virtualized import OpsValue
|
||||
|
||||
from ...virtualized import V
|
||||
@ -96,7 +96,14 @@ class _AssignmentFormatter(DefaultHandler):
|
||||
return OpsValue(line)
|
||||
else:
|
||||
var = self.parent_handler._tmp_var()
|
||||
self.parent_handler.body.writeline(f"{var} = {line}")
|
||||
line = DelayReplaceLine(
|
||||
var,
|
||||
lambda: "D"
|
||||
if var == self.parent_handler.last_stored_var_name
|
||||
else var,
|
||||
f"{var} = {line}",
|
||||
)
|
||||
self.parent_handler.body.writeline(line)
|
||||
return OpsValue(var)
|
||||
else:
|
||||
raise NotImplementedError(name)
|
||||
@ -138,6 +145,7 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
|
||||
self.cur_node: Optional[ComputedBuffer] = None
|
||||
self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
|
||||
self.is_D_assigned = False
|
||||
self.D_var_name = None
|
||||
|
||||
if accumulator_node_name not in removed_buffers:
|
||||
# cannot return accumulator directly, so alias it
|
||||
@ -162,6 +170,8 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
|
||||
index_vars = CutlassEVTCodegen.get_index_vars(node)
|
||||
node.get_store_function()(index_vars)
|
||||
|
||||
codegen.finalize()
|
||||
|
||||
return (
|
||||
codegen.get_reads(),
|
||||
codegen.get_writes(),
|
||||
@ -178,6 +188,17 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
|
||||
]
|
||||
)
|
||||
|
||||
def finalize(self) -> None:
|
||||
# Rename the last store to D
|
||||
# no other code references this store
|
||||
# to workaround https://github.com/NVIDIA/cutlass/issues/2288
|
||||
# Note: the delayed line will automatically rewrite the last assignment to
|
||||
# be to D
|
||||
buffer_name = self.var_name_to_buffer_name[self.last_stored_var_name]
|
||||
self.var_name_to_buffer_name.pop(self.last_stored_var_name)
|
||||
self.var_name_to_buffer_name["D"] = buffer_name
|
||||
self.store_name_to_value[buffer_name] = OpsValue("D")
|
||||
|
||||
@contextmanager
|
||||
def set_cur_node(self, node: ComputedBuffer) -> Generator[None, Any, Any]:
|
||||
prev_node = self.cur_node
|
||||
@ -213,22 +234,12 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
|
||||
if name not in self.removed_buffers:
|
||||
if index:
|
||||
self._check_indexing(name, index)
|
||||
|
||||
value_to_write = value
|
||||
if not self.is_D_assigned and name != self.accumulator_node_name:
|
||||
# EVT requires an output to be named D lol
|
||||
# so rename the first store to D
|
||||
# accumulator cannot be assigned to D
|
||||
# see https://github.com/NVIDIA/cutlass/issues/2288
|
||||
self.body.writeline(f"D = {value} # cutlass evt requirement")
|
||||
value_to_write = OpsValue("D")
|
||||
self.is_D_assigned = True
|
||||
|
||||
assert value_to_write.value != _ACCUMULATOR_ARG_NAME, (
|
||||
assert value.value != _ACCUMULATOR_ARG_NAME, (
|
||||
"Cannot store accumulator arg name"
|
||||
)
|
||||
self.var_name_to_buffer_name[value_to_write.value] = name
|
||||
self.store_name_to_value[name] = value_to_write
|
||||
self.var_name_to_buffer_name[value.value] = name
|
||||
self.store_name_to_value[name] = value
|
||||
self.last_stored_var_name = value.value
|
||||
return None
|
||||
|
||||
def _get_cur_node(self) -> ComputedBuffer:
|
||||
|
Reference in New Issue
Block a user