Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
add prologue fusion
This commit is contained in:
@ -255,6 +255,10 @@ prologue_fusion = prologue_fusion_enabled()
|
||||
# Run epilogue fusions ahead of the other fusion passes when True.
epilogue_fusion_first = False

# Custom-op fusion support: on by default, opt out by setting the
# corresponding environment variable to anything other than "1".
enable_custom_op_epilogue_fusion = (
    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
)
enable_custom_op_prologue_fusion = (
    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
)

# Enable pattern match+replace optimizations.
pattern_matcher = True
|
||||
|
||||
|
@ -178,6 +178,7 @@ def autotune_custom_op(
|
||||
dict[str, Callable[[torch.Tensor], torch.Tensor]]
|
||||
] = None,
|
||||
enable_epilogue_fusion: bool = False,
|
||||
enable_prologue_fusion: bool = False,
|
||||
) -> Union[TensorBox, Any]:
|
||||
"""Autotune custom operations by comparing multiple decomposition implementations.
|
||||
|
||||
@ -276,10 +277,13 @@ def autotune_custom_op(
|
||||
input_gen_fns=input_gen_fns,
|
||||
)
|
||||
|
||||
# Mark result for custom op epilogue fusion if enabled
|
||||
# Mark result for custom op fusion if enabled
|
||||
if enable_epilogue_fusion and isinstance(selected_result, TensorBox):
|
||||
_mark_custom_op_for_epilogue_fusion(selected_result, name)
|
||||
|
||||
if enable_prologue_fusion and isinstance(selected_result, TensorBox):
|
||||
_mark_custom_op_for_prologue_fusion(selected_result, name)
|
||||
|
||||
return selected_result
|
||||
|
||||
|
||||
@ -303,6 +307,26 @@ def _mark_custom_op_for_epilogue_fusion(result: TensorBox, name: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _mark_custom_op_for_prologue_fusion(result: TensorBox, name: str) -> None:
    """Tag *result* so the scheduler may apply custom op prologue fusion.

    Args:
        result: The autotuning result to tag.
        name: Operation name, recorded for identification.
    """
    data = getattr(result, "data", None)
    if data is None or not hasattr(data, "get_name"):
        # Not a named-buffer-backed result; nothing to tag.
        return

    # Lazily create the metadata dict on the buffer, then record that this
    # is a custom op result eligible for prologue fusion.
    if not hasattr(data, "_custom_op_fusion_metadata"):
        data._custom_op_fusion_metadata = {}
    meta = data._custom_op_fusion_metadata
    meta["prologue_fusion_enabled"] = True
    meta["custom_op_name"] = name
|
||||
|
||||
|
||||
def register_custom_op_autotuning(
|
||||
custom_op: torch._ops.OpOverload,
|
||||
decompositions: list[Callable[..., Any]],
|
||||
@ -310,6 +334,7 @@ def register_custom_op_autotuning(
|
||||
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
|
||||
tuning_knob: Optional[dict[str, list[Any]]] = None,
|
||||
enable_epilogue_fusion: bool = False,
|
||||
enable_prologue_fusion: bool = False,
|
||||
) -> None:
|
||||
"""Register custom operation for autotuning with multiple implementations.
|
||||
|
||||
@ -392,6 +417,7 @@ def register_custom_op_autotuning(
|
||||
default_impl=custom_op,
|
||||
user_input_gen_fns=input_gen_fns,
|
||||
enable_epilogue_fusion=enable_epilogue_fusion,
|
||||
enable_prologue_fusion=enable_prologue_fusion,
|
||||
)
|
||||
|
||||
validate_ir(result)
|
||||
|
@ -3217,9 +3217,21 @@ class Scheduler:
|
||||
and isinstance(n.get_template_node(), ir.MultiTemplateBuffer)
|
||||
for n in (node1, node2)
|
||||
)
|
||||
if not config.benchmark_fusion and not is_multi_template:
|
||||
# Check for custom op fusion cases
|
||||
is_custom_op_fusion = (
|
||||
self._can_apply_custom_op_epilogue_fusion(node1, node2) or
|
||||
self._can_apply_custom_op_prologue_fusion(node1, node2)
|
||||
)
|
||||
|
||||
if not config.benchmark_fusion and not is_multi_template and not is_custom_op_fusion:
|
||||
return True
|
||||
|
||||
# For custom op fusion, we want to benchmark by default unless explicitly disabled
|
||||
if is_custom_op_fusion and not config.benchmark_fusion:
|
||||
# Still allow benchmark for custom ops even if global benchmark_fusion is off
|
||||
fusion_log.debug("Benchmarking custom op fusion: %s <-> %s",
|
||||
node1.get_first_name(), node2.get_first_name())
|
||||
|
||||
if (
|
||||
node1.is_template()
|
||||
and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer)
|
||||
@ -4273,6 +4285,12 @@ class Scheduler:
|
||||
if not self.check_prologue_fusion_heuristics_fusable(node1, node2, why):
|
||||
return False
|
||||
|
||||
# Check for custom op prologue fusion
|
||||
if self._can_apply_custom_op_prologue_fusion(node1, node2):
|
||||
fusion_log.debug("Custom op prologue fusion applicable for %s -> %s",
|
||||
node1.get_first_name(), node2.get_first_name())
|
||||
return True
|
||||
|
||||
if node1.is_template() and (
|
||||
node2.has_aliasing_or_mutation()
|
||||
or node2.is_reduction()
|
||||
@ -4426,6 +4444,78 @@ class Scheduler:
|
||||
metadata.get('custom_op_name', 'unknown'))
|
||||
return True
|
||||
|
||||
# Enhanced check: also look for custom ops directly in the node
|
||||
if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
|
||||
hasattr(node1.node.data, 'name') and
|
||||
hasattr(node1.node.data, 'get_inputs')):
|
||||
|
||||
# Check if this is a result from our custom op autotune system
|
||||
if (hasattr(node1.node.data, 'get_name') and
|
||||
'_autotuned' in str(node1.node.data.get_name())):
|
||||
|
||||
# Apply similar checks as template epilogue fusion
|
||||
if (node2.is_pointwise() and
|
||||
not node2.is_reduction() and
|
||||
not node2.has_aliasing_or_mutation()):
|
||||
|
||||
fusion_log.debug("Custom op epilogue candidate: %s -> %s",
|
||||
node1.get_first_name(), node2.get_first_name())
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _can_apply_custom_op_prologue_fusion(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """Decide whether custom op prologue fusion applies to this node pair.

    Args:
        node1: Producer node (the candidate prologue operation)
        node2: Consumer node (the candidate custom op)

    Returns:
        bool: True if custom op prologue fusion is applicable
    """
    # Globally gated by the inductor config knob.
    from torch._inductor import config

    if not config.enable_custom_op_prologue_fusion:
        return False

    def producer_is_fusable() -> bool:
        # Same criteria as template prologue fusion: node1 must be a pure
        # pointwise op with no reduction and no aliasing/mutation.
        return (
            node1.is_pointwise()
            and not node1.is_reduction()
            and not node1.has_aliasing_or_mutation()
        )

    # Underlying IR buffer of the consumer, if reachable.
    data = getattr(getattr(node2, "node", None), "data", None)

    # Path 1: node2 was explicitly tagged by the custom op autotuner.
    if data is not None and hasattr(data, "_custom_op_fusion_metadata"):
        metadata = data._custom_op_fusion_metadata
        if metadata.get('prologue_fusion_enabled', False) and producer_is_fusable():
            fusion_log.info("Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
                            node1.get_first_name(), node2.get_first_name(),
                            metadata.get('custom_op_name', 'unknown'))
            return True

    # Path 2 (enhanced check): recognize autotuned custom op buffers by name.
    if (
        data is not None
        and hasattr(data, 'name')
        and hasattr(data, 'get_inputs')
        and hasattr(data, 'get_name')
        and '_autotuned' in str(data.get_name())
        and producer_is_fusable()
    ):
        fusion_log.debug("Custom op prologue candidate: %s -> %s",
                         node1.get_first_name(), node2.get_first_name())
        return True

    return False
|
||||
|
||||
def fusable_weak_dep(
|
||||
|
Reference in New Issue
Block a user