fix typo and lint

Author: Tianren Gao
Date: 2025-10-19 16:52:56 -07:00
parent 3c28278dff
commit 23cba2d68c
3 changed files with 113 additions and 64 deletions


@@ -256,8 +256,12 @@ prologue_fusion = prologue_fusion_enabled()
 epilogue_fusion_first = False

 # enable custom op fusion support
-enable_custom_op_epilogue_fusion = os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
-enable_custom_op_prologue_fusion = os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
+enable_custom_op_epilogue_fusion = (
+    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
+)
+enable_custom_op_prologue_fusion = (
+    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
+)

 # enable pattern match+replace optimizations
 pattern_matcher = True
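Both flags above default to enabled and are read once, when the config module is imported. A minimal sketch of how a user might turn them off (the environment-variable names come from this diff; the script around them is illustrative):

    import os

    # Must happen before torch._inductor.config is first imported,
    # because the defaults are captured at import time.
    os.environ["TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION"] = "0"
    os.environ["TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION"] = "0"

    from torch._inductor import config

    assert config.enable_custom_op_epilogue_fusion is False
    assert config.enable_custom_op_prologue_fusion is False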


@@ -65,7 +65,6 @@ def _create_user_input_gen_fns(
     Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
     """
-    from torch._inductor import config

     name_to_index = {name: i for i, name in enumerate(arg_names)}
     index_based_fns = {}


@@ -187,6 +187,16 @@ class SchedulerBuffer:
                 input_buffer,
                 self.node,
             )
+        else:
+            V.graph.wrapper_code.codegen_allocation(self.node)
+
+    def can_free(self) -> bool:
+        # There's no real allocated buffer, so there's no need to free it
+        assert self.node is not None
+        if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template(
+            self.node
+        ):
+            return False
         for use in self.users:
             if isinstance(use.node, OutputNode):
                 return False
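The new can_free() above short-circuits for buffers that were never really allocated (ir.NoneLayout) or that belong to a multi-output template, then falls through to the existing graph-output checks. A self-contained sketch of that predicate, with stand-ins replacing the scheduler/IR classes (illustrative only):

    class NoneLayout:  # stand-in for ir.NoneLayout: no backing allocation
        pass

    def can_free(
        layout: object, is_multi_outputs_template: bool, feeds_graph_output: bool
    ) -> bool:
        # No real allocated buffer -> nothing to free.
        if isinstance(layout, NoneLayout) or is_multi_outputs_template:
            return False
        # Buffers consumed by graph outputs must stay alive.
        if feeds_graph_output:
            return False
        return True

    assert can_free(NoneLayout(), False, False) is False
    assert can_free(object(), True, False) is False
    assert can_free(object(), False, True) is False
    assert can_free(object(), False, False) is True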
@@ -3218,19 +3228,25 @@ class Scheduler:
             for n in (node1, node2)
         )

         # Check for custom op fusion cases
-        is_custom_op_fusion = (
-            self._can_apply_custom_op_epilogue_fusion(node1, node2) or
-            self._can_apply_custom_op_prologue_fusion(node1, node2)
-        )
+        is_custom_op_fusion = self._can_apply_custom_op_epilogue_fusion(
+            node1, node2
+        ) or self._can_apply_custom_op_prologue_fusion(node1, node2)

-        if not config.benchmark_fusion and not is_multi_template and not is_custom_op_fusion:
+        if (
+            not config.benchmark_fusion
+            and not is_multi_template
+            and not is_custom_op_fusion
+        ):
             return True

         # For custom op fusion, we want to benchmark by default unless explicitly disabled
         if is_custom_op_fusion and not config.benchmark_fusion:
             # Still allow benchmark for custom ops even if global benchmark_fusion is off
-            fusion_log.debug("Benchmarking custom op fusion: %s <-> %s",
-                node1.get_first_name(), node2.get_first_name())
+            fusion_log.debug(
+                "Benchmarking custom op fusion: %s <-> %s",
+                node1.get_first_name(),
+                node2.get_first_name(),
+            )

         if (
             node1.is_template()
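The reformatted gate above preserves the decision flow: a fusion is accepted without benchmarking only when the global benchmark_fusion flag is off and neither the multi-output-template nor the custom-op case applies, so custom op fusions are benchmarked even with the global flag off. A condensed sketch of that predicate (the three boolean names mirror the diff; the function itself is illustrative):

    def skip_benchmarking(
        benchmark_fusion: bool,
        is_multi_template: bool,
        is_custom_op_fusion: bool,
    ) -> bool:
        # Skip only when nothing forces a benchmark.
        return (
            not benchmark_fusion
            and not is_multi_template
            and not is_custom_op_fusion
        )

    # Custom op fusion is benchmarked even when the global flag is off:
    assert skip_benchmarking(False, False, True) is False
    # A plain fusion with everything off skips benchmarking:
    assert skip_benchmarking(False, False, False) is True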
@@ -4287,8 +4303,11 @@ class Scheduler:
         # Check for custom op prologue fusion
         if self._can_apply_custom_op_prologue_fusion(node1, node2):
-            fusion_log.debug("Custom op prologue fusion applicable for %s -> %s",
-                node1.get_first_name(), node2.get_first_name())
+            fusion_log.debug(
+                "Custom op prologue fusion applicable for %s -> %s",
+                node1.get_first_name(),
+                node2.get_first_name(),
+            )
             return True

         if node1.is_template() and (
@@ -4363,8 +4382,11 @@ class Scheduler:
         """
         # Check for custom op epilogue fusion opportunities first
         if self._can_apply_custom_op_epilogue_fusion(node1, node2):
-            fusion_log.debug("Custom op epilogue fusion applicable for %s -> %s",
-                node1.get_first_name(), node2.get_first_name())
+            fusion_log.debug(
+                "Custom op epilogue fusion applicable for %s -> %s",
+                node1.get_first_name(),
+                node2.get_first_name(),
+            )
             return True

         node1_buf_names = node1.get_buffer_names()
@@ -4424,42 +4446,54 @@ class Scheduler:
         """
         # Check if global config enables custom op epilogue fusion
         from torch._inductor import config

         if not config.enable_custom_op_epilogue_fusion:
             return False

         # Check if node1 is marked as a custom op result eligible for epilogue fusion
-        if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
-            hasattr(node1.node.data, '_custom_op_fusion_metadata')):
+        if (
+            hasattr(node1, "node")
+            and hasattr(node1.node, "data")
+            and hasattr(node1.node.data, "_custom_op_fusion_metadata")
+        ):
             metadata = node1.node.data._custom_op_fusion_metadata
-            if metadata.get('epilogue_fusion_enabled', False):
+            if metadata.get("epilogue_fusion_enabled", False):
                 # Check if node2 is a suitable epilogue operation
-                if (node2.is_pointwise() and
-                    not node2.is_reduction() and
-                    not node2.has_aliasing_or_mutation()):
-                    fusion_log.info("Custom op epilogue fusion enabled for %s -> %s (custom_op: %s)",
-                        node1.get_first_name(), node2.get_first_name(),
-                        metadata.get('custom_op_name', 'unknown'))
+                if (
+                    node2.is_pointwise()
+                    and not node2.is_reduction()
+                    and not node2.has_aliasing_or_mutation()
+                ):
+                    fusion_log.info(
+                        "Custom op epilogue fusion enabled for %s -> %s (custom_op: %s)",
+                        node1.get_first_name(),
+                        node2.get_first_name(),
+                        metadata.get("custom_op_name", "unknown"),
+                    )
                     return True

         # Enhanced check: also look for custom ops directly in the node
-        if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
-            hasattr(node1.node.data, 'name') and
-            hasattr(node1.node.data, 'get_inputs')):
+        if (
+            hasattr(node1, "node")
+            and hasattr(node1.node, "data")
+            and hasattr(node1.node.data, "name")
+            and hasattr(node1.node.data, "get_inputs")
+        ):
             # Check if this is a result from our custom op autotune system
-            if (hasattr(node1.node.data, 'get_name') and
-                '_autotuned' in str(node1.node.data.get_name())):
+            if hasattr(node1.node.data, "get_name") and "_autotuned" in str(
+                node1.node.data.get_name()
+            ):
                 # Apply similar checks as template epilogue fusion
-                if (node2.is_pointwise() and
-                    not node2.is_reduction() and
-                    not node2.has_aliasing_or_mutation()):
-                    fusion_log.debug("Custom op epilogue candidate: %s -> %s",
-                        node1.get_first_name(), node2.get_first_name())
+                if (
+                    node2.is_pointwise()
+                    and not node2.is_reduction()
+                    and not node2.has_aliasing_or_mutation()
+                ):
+                    fusion_log.debug(
+                        "Custom op epilogue candidate: %s -> %s",
+                        node1.get_first_name(),
+                        node2.get_first_name(),
+                    )
                     return True

         return False
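Both branches of the checker above are duck-typed: the first fires only when a `_custom_op_fusion_metadata` dict has been attached to the underlying IR node's data, the second when the buffer name contains "_autotuned". A self-contained sketch of the object shape the first branch expects (attribute and key names come from the diff; the SimpleNamespace stand-ins and the op name are hypothetical):

    from types import SimpleNamespace

    data = SimpleNamespace()
    data._custom_op_fusion_metadata = {
        "epilogue_fusion_enabled": True,            # gates the fusion_log.info branch
        "custom_op_name": "my_lib::fused_rmsnorm",  # hypothetical; "unknown" if absent
    }
    node1 = SimpleNamespace(node=SimpleNamespace(data=data))

    # The hasattr() chain in the diff walks exactly this path:
    assert hasattr(node1, "node")
    assert hasattr(node1.node, "data")
    assert hasattr(node1.node.data, "_custom_op_fusion_metadata")
    metadata = node1.node.data._custom_op_fusion_metadata
    assert metadata.get("epilogue_fusion_enabled", False)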
@@ -4478,42 +4512,54 @@ class Scheduler:
         """
         # Check if global config enables custom op prologue fusion
         from torch._inductor import config

         if not config.enable_custom_op_prologue_fusion:
             return False

         # Check if node2 is marked as a custom op that supports prologue fusion
-        if (hasattr(node2, 'node') and hasattr(node2.node, 'data') and
-            hasattr(node2.node.data, '_custom_op_fusion_metadata')):
+        if (
+            hasattr(node2, "node")
+            and hasattr(node2.node, "data")
+            and hasattr(node2.node.data, "_custom_op_fusion_metadata")
+        ):
             metadata = node2.node.data._custom_op_fusion_metadata
-            if metadata.get('prologue_fusion_enabled', False):
+            if metadata.get("prologue_fusion_enabled", False):
                 # Check if node1 is a suitable prologue operation
-                if (node1.is_pointwise() and
-                    not node1.is_reduction() and
-                    not node1.has_aliasing_or_mutation()):
-                    fusion_log.info("Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
-                        node1.get_first_name(), node2.get_first_name(),
-                        metadata.get('custom_op_name', 'unknown'))
+                if (
+                    node1.is_pointwise()
+                    and not node1.is_reduction()
+                    and not node1.has_aliasing_or_mutation()
+                ):
+                    fusion_log.info(
+                        "Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
+                        node1.get_first_name(),
+                        node2.get_first_name(),
+                        metadata.get("custom_op_name", "unknown"),
+                    )
                     return True

         # Enhanced check: also look for custom ops directly in the node
-        if (hasattr(node2, 'node') and hasattr(node2.node, 'data') and
-            hasattr(node2.node.data, 'name') and
-            hasattr(node2.node.data, 'get_inputs')):
+        if (
+            hasattr(node2, "node")
+            and hasattr(node2.node, "data")
+            and hasattr(node2.node.data, "name")
+            and hasattr(node2.node.data, "get_inputs")
+        ):
             # Check if this is a result from our custom op autotune system
-            if (hasattr(node2.node.data, 'get_name') and
-                '_autotuned' in str(node2.node.data.get_name())):
+            if hasattr(node2.node.data, "get_name") and "_autotuned" in str(
+                node2.node.data.get_name()
+            ):
                 # Apply similar checks as template prologue fusion
-                if (node1.is_pointwise() and
-                    not node1.is_reduction() and
-                    not node1.has_aliasing_or_mutation()):
-                    fusion_log.debug("Custom op prologue candidate: %s -> %s",
-                        node1.get_first_name(), node2.get_first_name())
+                if (
+                    node1.is_pointwise()
+                    and not node1.is_reduction()
+                    and not node1.has_aliasing_or_mutation()
+                ):
+                    fusion_log.debug(
+                        "Custom op prologue candidate: %s -> %s",
+                        node1.get_first_name(),
+                        node2.get_first_name(),
+                    )
                     return True

         return False
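The prologue checker mirrors the epilogue one with the node roles swapped: node2 is now the custom op and node1 must be the fusable pointwise producer. A rough sketch of the two graph patterns being recognized, using plain matmul as a stand-in for the custom op (illustrative only):

    import torch

    def epilogue_pattern(a, b):
        y = a @ b              # custom-op stage (node1 in the epilogue check)
        return torch.relu(y)   # pointwise epilogue (node2: no reduction/mutation)

    def prologue_pattern(a, b):
        x = torch.sigmoid(a)   # pointwise prologue (node1: no reduction/mutation)
        return x @ b           # custom-op stage (node2 in the prologue check)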