mirror of https://github.com/pytorch/pytorch.git

commit: rebase
@@ -460,10 +460,7 @@ class TestCustomOpAutoTune(TestCase):
        ):
            compiled_result = epilogue_fusion_model(a, b, bias)

        # Since we disabled fallback, autotune must select one of the decompose_k variants
        # Compare against the same decompose_k implementation with k_splits=32 (first option)
        def reference_model(a, b, bias):
            # Use decompose_k with k_splits=32 (first value from tuning_knob)
            matmul_result = decompose_k_implementation(a, b, k_splits=32)
            biased = matmul_result + bias
            activated = torch.relu(biased)
@@ -472,35 +469,11 @@ class TestCustomOpAutoTune(TestCase):

        expected_final = reference_model(a, b, bias)

        # Debug: check the actual differences to understand why the test is failing
        abs_diff = torch.abs(compiled_result - expected_final)
        rel_diff = abs_diff / (torch.abs(expected_final) + 1e-8)

        max_abs_diff = torch.max(abs_diff).item()
        max_rel_diff = torch.max(rel_diff).item()
        mean_abs_diff = torch.mean(abs_diff).item()

        print("🔍 Numerical difference debug:")
        print(f" Max absolute difference: {max_abs_diff:.8f}")
        print(f" Max relative difference: {max_rel_diff:.8f}")
        print(f" Mean absolute difference: {mean_abs_diff:.8f}")
        print(
            f" Compiled result range: [{torch.min(compiled_result):.6f}, {torch.max(compiled_result):.6f}]"
        )
        print(
            f" Expected result range: [{torch.min(expected_final):.6f}, {torch.max(expected_final):.6f}]"
        )

        rtol, atol = 1e-2, 1e-2

        print(f" Using tolerance: rtol={rtol}, atol={atol}")

        torch.testing.assert_close(
            compiled_result,
            expected_final,
            rtol=rtol,
            atol=atol,
            msg=f"Decompose-k epilogue fusion numerical mismatch (max_abs_diff={max_abs_diff:.8f}, max_rel_diff={max_rel_diff:.8f})",
        )
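        # Illustration (not part of this patch): torch.testing.assert_close passes when
        # |actual - expected| <= atol + rtol * |expected|, checked elementwise.
        # A minimal standalone example of that rule with the tolerances used above:
        #
        #     expected = torch.tensor([100.0])
        #     actual = expected + 0.5  # absolute error 0.5
        #     # allowed error: 1e-2 + 1e-2 * 100.0 = 1.01, so this passes
        #     torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)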

    @skipIfXpu
@@ -256,17 +256,12 @@ prologue_fusion = prologue_fusion_enabled()
epilogue_fusion_first = False

# enable custom op fusion support
enable_custom_op_epilogue_fusion = (
    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
)
enable_custom_op_prologue_fusion = (
    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
)
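# Illustrative sketch (not part of this patch): these flags are read from the
# environment at import time, so the variables must be set before torch._inductor
# is imported. The flag names below exist only with this patch applied.
#
#     import os
#     os.environ["TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION"] = "0"  # opt out for one run
#     import torch._inductor.config as inductor_config
#     assert inductor_config.enable_custom_op_epilogue_fusion is False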

# enable pattern match+replace optimizations
pattern_matcher = True
@@ -387,60 +387,16 @@ def autotune_custom_op(

    # Apply inlining if epilogue fusion is enabled
    if enable_epilogue_fusion and isinstance(selected_result, TensorBox):
        # Find the winning choice that was selected during autotuning
        winning_choice = None

        # Debug: understand the structure of selected_result
        print(f"🔍 Debugging selected_result: {type(selected_result)}")
        print(f"🔍 selected_result.data: {type(selected_result.data)}")
        if hasattr(selected_result.data, "__dict__"):
            print(
                f"🔍 selected_result.data attributes: {list(selected_result.data.__dict__.keys())}"
            )

        # Try different ways to find the winning choice
        if hasattr(selected_result, "data") and hasattr(
            selected_result.data, "subgraph_name"
        ):
            # SubgraphBuffer case - find the matching choice by name
            subgraph_name = selected_result.data.subgraph_name
            print(f"🔍 Looking for subgraph_name: {subgraph_name}")
            for choice in choices:
                print(f"🔍 Choice name: {choice.name}")
                if choice.name == subgraph_name:
                    winning_choice = choice
                    break

        # Alternative: fall back to the first choice if no exact match is found
        if not winning_choice and choices:
            print(f"🔍 Using first choice as fallback: {choices[0].name}")
            winning_choice = choices[0]

        if winning_choice:
            print(f"🎯 Inlining winning choice: {winning_choice.name}")
            try:
                # Inline the winning choice's operations into the main graph
                inlined_result = _inline_custom_op_choice(winning_choice, inputs, name)
                return inlined_result
            except Exception as e:
                print(f"❌ Inlining failed: {e}")
                print("⚠️ Falling back to marking approach")
        else:
            print(
                "⚠️ Could not find winning choice for inlining, falling back to marking"
            )

    # Mark the result for custom op fusion if enabled (fallback path)
    if enable_epilogue_fusion and isinstance(selected_result, TensorBox):
        _mark_custom_op_for_epilogue_fusion(selected_result, name)

    if enable_prologue_fusion and isinstance(selected_result, TensorBox):
        _mark_custom_op_for_prologue_fusion(selected_result, name)
        winning_choice = choices[0]  # TODO: use the selected choice instead of choices[0]
        inlined_result = _inline_custom_op_choice(winning_choice, inputs, name)
        return inlined_result

    return selected_result
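
# Context sketch (not from this patch): the point of inlining the winning choice is
# that Inductor can fuse pointwise epilogues into a surrounding kernel only when the
# op is expressed as fusable IR nodes rather than one opaque ExternKernel. The same
# effect for a plain op can be seen with torch.compile:
#
#     import torch
#
#     def f(a, b, bias):
#         return torch.relu(a @ b + bias)  # bias add + relu form a pointwise epilogue
#
#     compiled = torch.compile(f)  # Inductor may fuse the epilogue into the GEMM kernel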


def _inline_custom_op_choice(
    winning_choice: Any, inputs: list[Any], name: str
) -> TensorBox:
    """Inline the winning custom op choice by converting its FX operations to individual IR nodes.

    This converts the custom op from a single ExternKernel (unfusable) to multiple ComputedBuffer
@@ -459,10 +415,6 @@ def _inline_custom_op_choice(winning_choice, inputs: list[Any], name: str) -> Te
    # Get the GraphModule containing the operations
    gm = winning_choice.gm

    # Create a temporary graph lowering context to process the FX nodes
    # We'll extract the operations and add them to the current graph
    current_graph = V.graph

    # Create mapping from placeholder nodes to actual inputs
    node_to_value = {}
    placeholder_idx = 0
@@ -516,46 +468,6 @@ def _inline_custom_op_choice(winning_choice, inputs: list[Any], name: str) -> Te
        raise RuntimeError("No output node found in custom op graph")


def _mark_custom_op_for_epilogue_fusion(result: TensorBox, name: str) -> None:
    """Mark the result for custom op epilogue fusion by the scheduler.

    Args:
        result: The autotuning result to mark
        name: Operation name for identification
    """
    if hasattr(result, "data") and hasattr(result.data, "get_name"):
        # Mark this buffer as a custom op result eligible for epilogue fusion
        if not hasattr(result.data, "_custom_op_fusion_metadata"):
            result.data._custom_op_fusion_metadata = {}

        result.data._custom_op_fusion_metadata.update(
            {
                "epilogue_fusion_enabled": True,
                "custom_op_name": name,
            }
        )


def _mark_custom_op_for_prologue_fusion(result: TensorBox, name: str) -> None:
    """Mark the result for custom op prologue fusion by the scheduler.

    Args:
        result: The autotuning result to mark
        name: Operation name for identification
    """
    if hasattr(result, "data") and hasattr(result.data, "get_name"):
        # Mark this buffer as a custom op result eligible for prologue fusion
        if not hasattr(result.data, "_custom_op_fusion_metadata"):
            result.data._custom_op_fusion_metadata = {}

        result.data._custom_op_fusion_metadata.update(
            {
                "prologue_fusion_enabled": True,
                "custom_op_name": name,
            }
        )
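
# Illustration (not part of this patch): after marking, the buffer simply carries a
# plain dict that the scheduler's _can_apply_* checks read back, e.g.:
#
#     result.data._custom_op_fusion_metadata == {
#         "epilogue_fusion_enabled": True,
#         "custom_op_name": "my_custom_op",  # "my_custom_op" is a placeholder name
#     }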


def register_custom_op_autotuning(
    custom_op: torch._ops.OpOverload,
    configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
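# Hypothetical usage sketch: the names register_custom_op_autotuning, CustomOpConfig,
# and decompose_k_implementation come from this diff, but the exact call shape below
# is an assumption, not a confirmed API:
#
#     register_custom_op_autotuning(
#         custom_op=torch.ops.mylib.my_matmul.default,  # placeholder op
#         configs=[
#             CustomOpConfig(decompose_k_implementation, k_splits=32),
#             CustomOpConfig(decompose_k_implementation, k_splits=64),
#         ],
#     )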
@@ -187,6 +187,16 @@ class SchedulerBuffer:
                input_buffer,
                self.node,
            )
        else:
            V.graph.wrapper_code.codegen_allocation(self.node)

    def can_free(self) -> bool:
        # There's no real allocated buffer, no need to free it
        assert self.node is not None
        if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template(
            self.node
        ):
            return False
        for use in self.users:
            if isinstance(use.node, OutputNode):
                return False
@@ -3242,20 +3252,6 @@ class Scheduler:
            and isinstance(n.get_template_node(), ir.MultiTemplateBuffer)
            for n in (node1, node2)
        )
        # Check for custom op fusion cases
        is_custom_op_fusion = self._can_apply_custom_op_epilogue_fusion(
            node1, node2
        ) or self._can_apply_custom_op_prologue_fusion(node1, node2)

        if (
            not config.benchmark_fusion
            and not is_multi_template
            and not is_custom_op_fusion
        ):
            return True

        # For custom op fusion, benchmark by default unless explicitly disabled
        if is_custom_op_fusion and not config.benchmark_fusion:
            # Still allow benchmarking for custom ops even if global benchmark_fusion is off
            fusion_log.debug(
                "Benchmarking custom op fusion: %s <-> %s",
                node1.get_first_name(),
                node2.get_first_name(),
            )

        if (
            node1.is_template()
@@ -4328,12 +4324,6 @@ class Scheduler:
        if not self.check_prologue_fusion_heuristics_fusable(node1, node2, why):
            return False

        # Check for custom op prologue fusion
        if self._can_apply_custom_op_prologue_fusion(node1, node2):
            fusion_log.debug(
                "Custom op prologue fusion applicable for %s -> %s",
                node1.get_first_name(),
                node2.get_first_name(),
            )
            return True

        if node1.is_template() and (
            node2.has_aliasing_or_mutation()
            or node2.is_reduction()
@@ -4404,14 +4394,6 @@ class Scheduler:
        corresponding writes in node1, or are written by nodes that can
        be scheduled before the fusion of node1 and node2.
        """
        # Check for custom op epilogue fusion opportunities first
        if self._can_apply_custom_op_epilogue_fusion(node1, node2):
            fusion_log.debug(
                "Custom op epilogue fusion applicable for %s -> %s",
                node1.get_first_name(),
                node2.get_first_name(),
            )
            return True

        node1_buf_names = node1.get_buffer_names()
        why = WhyNoFuse(node1, node2)
@@ -4456,114 +4438,6 @@ class Scheduler:

        return True

    def _can_apply_custom_op_epilogue_fusion(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """Check if custom op epilogue fusion can be applied between two nodes.

        Args:
            node1: Producer node (potential custom op result)
            node2: Consumer node (potential epilogue operation)

        Returns:
            bool: True if custom op epilogue fusion is applicable
        """
        # Check if the global config enables custom op epilogue fusion
        from torch._inductor import config

        if not config.enable_custom_op_epilogue_fusion:
            return False

        # Check if node1 is marked as a custom op result eligible for epilogue fusion
        if (
            hasattr(node1, "node")
            and hasattr(node1.node, "data")
            and hasattr(node1.node.data, "_custom_op_fusion_metadata")
        ):
            metadata = node1.node.data._custom_op_fusion_metadata
            if metadata.get("epilogue_fusion_enabled", False):
                # Check if node2 is a suitable epilogue operation
                if (
                    node2.is_pointwise()
                    and not node2.is_reduction()
                    and not node2.has_aliasing_or_mutation()
                ):
                    fusion_log.info(
                        "Custom op epilogue fusion enabled for %s -> %s (custom_op: %s)",
                        node1.get_first_name(),
                        node2.get_first_name(),
                        metadata.get("custom_op_name", "unknown"),
                    )
                    return True

        # Enhanced check: also look for custom ops directly in the node
        if (
            hasattr(node1, "node")
            and hasattr(node1.node, "data")
            and hasattr(node1.node.data, "name")
            and hasattr(node1.node.data, "get_inputs")
        ):
            # Check if this is a result from our custom op autotune system
            if hasattr(node1.node.data, "get_name") and "_autotuned" in str(
                node1.node.data.get_name()
            ):
                # Apply similar checks as template epilogue fusion
                if (
                    node2.is_pointwise()
                    and not node2.is_reduction()
                    and not node2.has_aliasing_or_mutation()
                ):
                    fusion_log.debug(
                        "Custom op epilogue candidate: %s -> %s",
                        node1.get_first_name(),
                        node2.get_first_name(),
                    )
                    return True

        return False

    def _can_apply_custom_op_prologue_fusion(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """Check if custom op prologue fusion can be applied between two nodes.

        Args:
            node1: Producer node (potential prologue operation)
            node2: Consumer node (potential custom op)

        Returns:
            bool: True if custom op prologue fusion is applicable
        """
        # Check if the global config enables custom op prologue fusion
        from torch._inductor import config

        if not config.enable_custom_op_prologue_fusion:
            return False

        # Check if node2 is marked as a custom op that supports prologue fusion
        if (
            hasattr(node2, "node")
            and hasattr(node2.node, "data")
            and hasattr(node2.node.data, "_custom_op_fusion_metadata")
        ):
            metadata = node2.node.data._custom_op_fusion_metadata
            if metadata.get("prologue_fusion_enabled", False):
                # Check if node1 is a suitable prologue operation
                if (
                    node1.is_pointwise()
                    and not node1.is_reduction()
                    and not node1.has_aliasing_or_mutation()
                ):
                    fusion_log.info(
                        "Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
                        node1.get_first_name(),
                        node2.get_first_name(),
                        metadata.get("custom_op_name", "unknown"),
                    )
                    return True

        # Enhanced check: also look for custom ops directly in the node
        if (
            hasattr(node2, "node")
            and hasattr(node2.node, "data")
            and hasattr(node2.node.data, "name")
            and hasattr(node2.node.data, "get_inputs")
        ):
            # Check if this is a result from our custom op autotune system
            if hasattr(node2.node.data, "get_name") and "_autotuned" in str(
                node2.node.data.get_name()
            ):
                # Apply similar checks as template prologue fusion
                if (
                    node1.is_pointwise()
                    and not node1.is_reduction()
                    and not node1.has_aliasing_or_mutation()
                ):
                    fusion_log.debug(
                        "Custom op prologue candidate: %s -> %s",
                        node1.get_first_name(),
                        node2.get_first_name(),
                    )
                    return True

        return False
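
    # Illustration (not part of this patch): the shape of the pattern the prologue
    # check above targets is a pointwise producer feeding an autotuned custom op:
    #
    #     scaled = x * 2.0                       # node1: pointwise, no reduction/mutation
    #     out = my_autotuned_custom_op(scaled)   # node2: carries prologue-fusion metadata
    #
    # my_autotuned_custom_op is a placeholder name, not an API from this patch.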

    def fusable_weak_dep(
        self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool: