Tianren Gao
2025-10-20 17:49:33 -07:00
parent 61cf16bfb3
commit 4aa88aff67
4 changed files with 18 additions and 264 deletions

View File

@@ -460,10 +460,7 @@ class TestCustomOpAutoTune(TestCase):
):
compiled_result = epilogue_fusion_model(a, b, bias)
# Since we disabled fallback, autotune must select one of the decompose_k variants
# Compare against the same decompose_k implementation with k_splits=32 (first option)
def reference_model(a, b, bias):
# Use decompose_k with k_splits=32 (first value from tuning_knob)
matmul_result = decompose_k_implementation(a, b, k_splits=32)
biased = matmul_result + bias
activated = torch.relu(biased)
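(For context: decompose-k splits the shared K dimension of a matmul into k_splits chunks, computes the partial products in a batch, and sums them. A minimal sketch of what the decompose_k_implementation referenced above could look like — the reshape strategy and the assumption that K is divisible by k_splits are illustrative, not taken from this commit:)

import torch

def decompose_k_implementation(a: torch.Tensor, b: torch.Tensor, k_splits: int = 32) -> torch.Tensor:
    # Split K into k_splits chunks: a -> (k_splits, M, K//k_splits),
    # b -> (k_splits, K//k_splits, N), then batch-matmul and sum partials.
    m, k = a.shape
    _, n = b.shape
    a_chunks = a.reshape(m, k_splits, k // k_splits).permute(1, 0, 2)
    b_chunks = b.reshape(k_splits, k // k_splits, n)
    return torch.bmm(a_chunks, b_chunks).sum(dim=0)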
@@ -472,35 +469,11 @@ class TestCustomOpAutoTune(TestCase):
expected_final = reference_model(a, b, bias)
# Debug: Check actual differences to understand why it's failing
abs_diff = torch.abs(compiled_result - expected_final)
rel_diff = abs_diff / (torch.abs(expected_final) + 1e-8)
max_abs_diff = torch.max(abs_diff).item()
max_rel_diff = torch.max(rel_diff).item()
mean_abs_diff = torch.mean(abs_diff).item()
print("🔍 Numerical difference debug:")
print(f" Max absolute difference: {max_abs_diff:.8f}")
print(f" Max relative difference: {max_rel_diff:.8f}")
print(f" Mean absolute difference: {mean_abs_diff:.8f}")
print(
f" Compiled result range: [{torch.min(compiled_result):.6f}, {torch.max(compiled_result):.6f}]"
)
print(
f" Expected result range: [{torch.min(expected_final):.6f}, {torch.max(expected_final):.6f}]"
)
rtol, atol = 1, 1
print(f" Using tolerance: rtol={rtol}, atol={atol}")
torch.testing.assert_close(
compiled_result,
expected_final,
rtol=1e-2,
atol=1e-2,
)
@skipIfXpu

View File

@@ -256,17 +256,12 @@ prologue_fusion = prologue_fusion_enabled()
epilogue_fusion_first = False
# enable custom op fusion support
enable_custom_op_epilogue_fusion = (
    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
)
enable_custom_op_prologue_fusion = (
    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
)
# enable pattern match+replace optimizations
pattern_matcher = True
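(Both flags default to on and are evaluated at module import time, so they have to be set in the environment before torch._inductor.config is first imported. A minimal sketch of toggling them off — the script structure is illustrative:)

import os

# Must run before the first import of torch._inductor.config,
# since the flags are read when that module is imported.
os.environ["TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION"] = "0"
os.environ["TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION"] = "0"

import torch  # noqa: E402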

View File

@@ -387,60 +387,16 @@ def autotune_custom_op(
# Apply inlining if epilogue fusion is enabled
if enable_epilogue_fusion and isinstance(selected_result, TensorBox):
# Find the winning choice that was selected during autotuning
winning_choice = None
# Debug: Let's understand the structure of selected_result
print(f"🔍 Debugging selected_result: {type(selected_result)}")
print(f"🔍 selected_result.data: {type(selected_result.data)}")
if hasattr(selected_result.data, "__dict__"):
print(
f"🔍 selected_result.data attributes: {list(selected_result.data.__dict__.keys())}"
)
# Try different ways to find the winning choice
if hasattr(selected_result, "data") and hasattr(
selected_result.data, "subgraph_name"
):
# SubgraphBuffer case - find matching choice by name
subgraph_name = selected_result.data.subgraph_name
print(f"🔍 Looking for subgraph_name: {subgraph_name}")
for choice in choices:
print(f"🔍 Choice name: {choice.name}")
if choice.name == subgraph_name:
winning_choice = choice
break
# Alternative: the first choice might be the winner if we can't find an exact match
if not winning_choice and choices:
print(f"🔍 Using first choice as fallback: {choices[0].name}")
winning_choice = choices[0]
if winning_choice:
print(f"🎯 Inlining winning choice: {winning_choice.name}")
try:
# Inline the winning choice operations into the main graph
inlined_result = _inline_custom_op_choice(winning_choice, inputs, name)
return inlined_result
except Exception as e:
print(f"❌ Inlining failed: {e}")
print("⚠️ Falling back to marking approach")
else:
print(
"⚠️ Could not find winning choice for inlining, falling back to marking"
)
# Mark result for custom op fusion if enabled (fallback path)
if enable_epilogue_fusion and isinstance(selected_result, TensorBox):
_mark_custom_op_for_epilogue_fusion(selected_result, name)
if enable_prologue_fusion and isinstance(selected_result, TensorBox):
_mark_custom_op_for_prologue_fusion(selected_result, name)
winning_choice = choices[0]  # TODO: use the selected choice instead of choices[0]
inlined_result = _inline_custom_op_choice(winning_choice, inputs, name)
return inlined_result
return selected_result
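(The search logic removed above matched selected_result.data.subgraph_name against each choice's name before falling back to choices[0]; condensed into a standalone helper, the pattern looks roughly like this — the helper name is mine:)

def find_winning_choice(selected_result, choices):
    # Match the selected subgraph back to the choice that generated it;
    # fall back to the first choice when no exact match is found.
    data = getattr(selected_result, "data", None)
    subgraph_name = getattr(data, "subgraph_name", None)
    for choice in choices:
        if choice.name == subgraph_name:
            return choice
    return choices[0] if choices else None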
def _inline_custom_op_choice(
    winning_choice: Any, inputs: list[Any], name: str
) -> TensorBox:
"""Inline the winning custom op choice by converting its FX operations to individual IR nodes.
This converts the custom op from a single ExternKernel (unfusable) to multiple ComputedBuffer
@@ -459,10 +415,6 @@ def _inline_custom_op_choice(winning_choice, inputs: list[Any], name: str) -> TensorBox:
# Get the GraphModule containing the operations
gm = winning_choice.gm
# Create a temporary graph lowering context to process the FX nodes
# We'll extract the operations and add them to the current graph
current_graph = V.graph
# Create mapping from placeholder nodes to actual inputs
node_to_value = {}
placeholder_idx = 0
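(The node_to_value/placeholder_idx bookkeeping above pairs each FX placeholder with its runtime input in graph order; as a self-contained sketch — the function name is mine:)

import torch.fx

def map_placeholders_to_inputs(gm: torch.fx.GraphModule, inputs):
    # Walk the graph in order; the i-th placeholder binds to inputs[i].
    node_to_value = {}
    placeholder_idx = 0
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            node_to_value[node] = inputs[placeholder_idx]
            placeholder_idx += 1
    return node_to_value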
@@ -516,46 +468,6 @@ def _inline_custom_op_choice(winning_choice, inputs: list[Any], name: str) -> TensorBox:
raise RuntimeError("No output node found in custom op graph")
def _mark_custom_op_for_epilogue_fusion(result: TensorBox, name: str) -> None:
"""Mark the result for custom op epilogue fusion by the scheduler.
Args:
result: The autotuning result to mark
name: Operation name for identification
"""
if hasattr(result, "data") and hasattr(result.data, "get_name"):
# Mark this buffer as a custom op result eligible for epilogue fusion
if not hasattr(result.data, "_custom_op_fusion_metadata"):
result.data._custom_op_fusion_metadata = {}
result.data._custom_op_fusion_metadata.update(
{
"epilogue_fusion_enabled": True,
"custom_op_name": name,
}
)
def _mark_custom_op_for_prologue_fusion(result: TensorBox, name: str) -> None:
"""Mark the result for custom op prologue fusion by the scheduler.
Args:
result: The autotuning result to mark
name: Operation name for identification
"""
if hasattr(result, "data") and hasattr(result.data, "get_name"):
# Mark this buffer as a custom op result eligible for prologue fusion
if not hasattr(result.data, "_custom_op_fusion_metadata"):
result.data._custom_op_fusion_metadata = {}
result.data._custom_op_fusion_metadata.update(
{
"prologue_fusion_enabled": True,
"custom_op_name": name,
}
)
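(The marking helpers above only attach a metadata dict; the scheduler-side checks removed later in this commit read it back. A minimal sketch of that read path, mirroring the hasattr chain used there — the helper name is mine:)

def read_fusion_metadata(scheduler_node):
    # Walk scheduler node -> IR node -> data and return the metadata
    # dict attached by the marking helpers, or an empty dict.
    data = getattr(getattr(scheduler_node, "node", None), "data", None)
    return getattr(data, "_custom_op_fusion_metadata", None) or {}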
def register_custom_op_autotuning(
custom_op: torch._ops.OpOverload,
configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],

View File

@@ -187,6 +187,16 @@ class SchedulerBuffer:
input_buffer,
self.node,
)
else:
V.graph.wrapper_code.codegen_allocation(self.node)
def can_free(self) -> bool:
# There's no real allocated buffer, no need to free it
assert self.node is not None
if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template(
self.node
):
return False
for use in self.users:
if isinstance(use.node, OutputNode):
return False
@@ -3242,20 +3252,6 @@ class Scheduler:
and isinstance(n.get_template_node(), ir.MultiTemplateBuffer)
for n in (node1, node2)
)
# Check for custom op fusion cases
is_custom_op_fusion = (
self._can_apply_custom_op_epilogue_fusion(node1, node2) or
self._can_apply_custom_op_prologue_fusion(node1, node2)
)
if not config.benchmark_fusion and not is_multi_template and not is_custom_op_fusion:
return True
# For custom op fusion, we want to benchmark by default unless explicitly disabled
if is_custom_op_fusion and not config.benchmark_fusion:
# Still allow benchmark for custom ops even if global benchmark_fusion is off
fusion_log.debug("Benchmarking custom op fusion: %s <-> %s",
node1.get_first_name(), node2.get_first_name())
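(Condensing the removed control flow: benchmarking is skipped only when benchmark_fusion is off and neither a multi-template nor a custom-op fusion is involved, so custom-op fusions get benchmarked even when the global flag is off. As a standalone predicate — the name is mine:)

def skip_fusion_benchmark(benchmark_fusion, is_multi_template, is_custom_op_fusion):
    # Early-out mirrors the removed `if not ... and not ... and not ...`.
    return not benchmark_fusion and not is_multi_template and not is_custom_op_fusion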
if (
node1.is_template()
@@ -4328,12 +4324,6 @@ class Scheduler:
if not self.check_prologue_fusion_heuristics_fusable(node1, node2, why):
return False
# Check for custom op prologue fusion
if self._can_apply_custom_op_prologue_fusion(node1, node2):
fusion_log.debug("Custom op prologue fusion applicable for %s -> %s",
node1.get_first_name(), node2.get_first_name())
return True
if node1.is_template() and (
node2.has_aliasing_or_mutation()
or node2.is_reduction()
@@ -4404,14 +4394,6 @@ class Scheduler:
corresponding writes in node1, or are written by nodes that can
be scheduled before the fusion of node1 and node2.
"""
# Check for custom op epilogue fusion opportunities first
if self._can_apply_custom_op_epilogue_fusion(node1, node2):
    fusion_log.debug("Custom op epilogue fusion applicable for %s -> %s",
                     node1.get_first_name(), node2.get_first_name())
    return True
node1_buf_names = node1.get_buffer_names()
why = WhyNoFuse(node1, node2)
@@ -4456,114 +4438,6 @@ class Scheduler:
return True
def _can_apply_custom_op_epilogue_fusion(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
"""Check if custom op epilogue fusion can be applied between two nodes.
Args:
node1: Producer node (potential custom op result)
node2: Consumer node (potential epilogue operation)
Returns:
bool: True if custom op epilogue fusion is applicable
"""
# Check if global config enables custom op epilogue fusion
from torch._inductor import config
if not config.enable_custom_op_epilogue_fusion:
return False
# Check if node1 is marked as a custom op result eligible for epilogue fusion
if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
hasattr(node1.node.data, '_custom_op_fusion_metadata')):
metadata = node1.node.data._custom_op_fusion_metadata
if metadata.get('epilogue_fusion_enabled', False):
# Check if node2 is a suitable epilogue operation
if (node2.is_pointwise() and
not node2.is_reduction() and
not node2.has_aliasing_or_mutation()):
fusion_log.info("Custom op epilogue fusion enabled for %s -> %s (custom_op: %s)",
node1.get_first_name(), node2.get_first_name(),
metadata.get('custom_op_name', 'unknown'))
return True
# Enhanced check: also look for custom ops directly in the node
if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
hasattr(node1.node.data, 'name') and
hasattr(node1.node.data, 'get_inputs')):
# Check if this is a result from our custom op autotune system
if (hasattr(node1.node.data, 'get_name') and
'_autotuned' in str(node1.node.data.get_name())):
# Apply similar checks as template epilogue fusion
if (node2.is_pointwise() and
not node2.is_reduction() and
not node2.has_aliasing_or_mutation()):
fusion_log.debug("Custom op epilogue candidate: %s -> %s",
node1.get_first_name(), node2.get_first_name())
return True
return False
def _can_apply_custom_op_prologue_fusion(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
"""Check if custom op prologue fusion can be applied between two nodes.
Args:
node1: Producer node (potential prologue operation)
node2: Consumer node (potential custom op)
Returns:
bool: True if custom op prologue fusion is applicable
"""
# Check if global config enables custom op prologue fusion
from torch._inductor import config
if not config.enable_custom_op_prologue_fusion:
return False
# Check if node2 is marked as a custom op that supports prologue fusion
if (hasattr(node2, 'node') and hasattr(node2.node, 'data') and
hasattr(node2.node.data, '_custom_op_fusion_metadata')):
metadata = node2.node.data._custom_op_fusion_metadata
if metadata.get('prologue_fusion_enabled', False):
# Check if node1 is a suitable prologue operation
if (node1.is_pointwise() and
not node1.is_reduction() and
not node1.has_aliasing_or_mutation()):
fusion_log.info("Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
node1.get_first_name(), node2.get_first_name(),
metadata.get('custom_op_name', 'unknown'))
return True
# Enhanced check: also look for custom ops directly in the node
if (hasattr(node2, 'node') and hasattr(node2.node, 'data') and
hasattr(node2.node.data, 'name') and
hasattr(node2.node.data, 'get_inputs')):
# Check if this is a result from our custom op autotune system
if (hasattr(node2.node.data, 'get_name') and
'_autotuned' in str(node2.node.data.get_name())):
# Apply similar checks as template prologue fusion
if (node1.is_pointwise() and
not node1.is_reduction() and
not node1.has_aliasing_or_mutation()):
fusion_log.debug("Custom op prologue candidate: %s -> %s",
node1.get_first_name(), node2.get_first_name())
return True
return False
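(Both removed checks gate on the same node-shape predicate, applied to the consumer for epilogue fusion and to the producer for prologue fusion; factored out, it is simply — the helper name is mine:)

def _is_fusable_pointwise(node):
    # Shared eligibility test from both checks above: pointwise,
    # not a reduction, and free of aliasing or mutation.
    return (
        node.is_pointwise()
        and not node.is_reduction()
        and not node.has_aliasing_or_mutation()
    )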
def fusable_weak_dep(
self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool: