mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix typo and lint
This commit is contained in:
@ -256,8 +256,12 @@ prologue_fusion = prologue_fusion_enabled()
|
|||||||
epilogue_fusion_first = False
|
epilogue_fusion_first = False
|
||||||
|
|
||||||
# enable custom op fusion support
|
# enable custom op fusion support
|
||||||
enable_custom_op_epilogue_fusion = os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
|
enable_custom_op_epilogue_fusion = (
|
||||||
enable_custom_op_prologue_fusion = os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
|
os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
|
||||||
|
)
|
||||||
|
enable_custom_op_prologue_fusion = (
|
||||||
|
os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
|
||||||
|
)
|
||||||
|
|
||||||
# enable pattern match+replace optimizations
|
# enable pattern match+replace optimizations
|
||||||
pattern_matcher = True
|
pattern_matcher = True
|
||||||
|
@ -65,7 +65,6 @@ def _create_user_input_gen_fns(
|
|||||||
|
|
||||||
Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
|
Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
|
||||||
"""
|
"""
|
||||||
from torch._inductor import config
|
|
||||||
|
|
||||||
name_to_index = {name: i for i, name in enumerate(arg_names)}
|
name_to_index = {name: i for i, name in enumerate(arg_names)}
|
||||||
index_based_fns = {}
|
index_based_fns = {}
|
||||||
|
@ -187,6 +187,16 @@ class SchedulerBuffer:
|
|||||||
input_buffer,
|
input_buffer,
|
||||||
self.node,
|
self.node,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
V.graph.wrapper_code.codegen_allocation(self.node)
|
||||||
|
|
||||||
|
def can_free(self) -> bool:
|
||||||
|
# There's no real allocated buffer, no need to free it
|
||||||
|
assert self.node is not None
|
||||||
|
if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template(
|
||||||
|
self.node
|
||||||
|
):
|
||||||
|
return False
|
||||||
for use in self.users:
|
for use in self.users:
|
||||||
if isinstance(use.node, OutputNode):
|
if isinstance(use.node, OutputNode):
|
||||||
return False
|
return False
|
||||||
@ -3218,19 +3228,25 @@ class Scheduler:
|
|||||||
for n in (node1, node2)
|
for n in (node1, node2)
|
||||||
)
|
)
|
||||||
# Check for custom op fusion cases
|
# Check for custom op fusion cases
|
||||||
is_custom_op_fusion = (
|
is_custom_op_fusion = self._can_apply_custom_op_epilogue_fusion(
|
||||||
self._can_apply_custom_op_epilogue_fusion(node1, node2) or
|
node1, node2
|
||||||
self._can_apply_custom_op_prologue_fusion(node1, node2)
|
) or self._can_apply_custom_op_prologue_fusion(node1, node2)
|
||||||
)
|
|
||||||
|
|
||||||
if not config.benchmark_fusion and not is_multi_template and not is_custom_op_fusion:
|
if (
|
||||||
|
not config.benchmark_fusion
|
||||||
|
and not is_multi_template
|
||||||
|
and not is_custom_op_fusion
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# For custom op fusion, we want to benchmark by default unless explicitly disabled
|
# For custom op fusion, we want to benchmark by default unless explicitly disabled
|
||||||
if is_custom_op_fusion and not config.benchmark_fusion:
|
if is_custom_op_fusion and not config.benchmark_fusion:
|
||||||
# Still allow benchmark for custom ops even if global benchmark_fusion is off
|
# Still allow benchmark for custom ops even if global benchmark_fusion is off
|
||||||
fusion_log.debug("Benchmarking custom op fusion: %s <-> %s",
|
fusion_log.debug(
|
||||||
node1.get_first_name(), node2.get_first_name())
|
"Benchmarking custom op fusion: %s <-> %s",
|
||||||
|
node1.get_first_name(),
|
||||||
|
node2.get_first_name(),
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
node1.is_template()
|
node1.is_template()
|
||||||
@ -4287,8 +4303,11 @@ class Scheduler:
|
|||||||
|
|
||||||
# Check for custom op prologue fusion
|
# Check for custom op prologue fusion
|
||||||
if self._can_apply_custom_op_prologue_fusion(node1, node2):
|
if self._can_apply_custom_op_prologue_fusion(node1, node2):
|
||||||
fusion_log.debug("Custom op prologue fusion applicable for %s -> %s",
|
fusion_log.debug(
|
||||||
node1.get_first_name(), node2.get_first_name())
|
"Custom op prologue fusion applicable for %s -> %s",
|
||||||
|
node1.get_first_name(),
|
||||||
|
node2.get_first_name(),
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if node1.is_template() and (
|
if node1.is_template() and (
|
||||||
@ -4363,8 +4382,11 @@ class Scheduler:
|
|||||||
"""
|
"""
|
||||||
# Check for custom op epilogue fusion opportunities first
|
# Check for custom op epilogue fusion opportunities first
|
||||||
if self._can_apply_custom_op_epilogue_fusion(node1, node2):
|
if self._can_apply_custom_op_epilogue_fusion(node1, node2):
|
||||||
fusion_log.debug("Custom op epilogue fusion applicable for %s -> %s",
|
fusion_log.debug(
|
||||||
node1.get_first_name(), node2.get_first_name())
|
"Custom op epilogue fusion applicable for %s -> %s",
|
||||||
|
node1.get_first_name(),
|
||||||
|
node2.get_first_name(),
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
node1_buf_names = node1.get_buffer_names()
|
node1_buf_names = node1.get_buffer_names()
|
||||||
@ -4424,42 +4446,54 @@ class Scheduler:
|
|||||||
"""
|
"""
|
||||||
# Check if global config enables custom op epilogue fusion
|
# Check if global config enables custom op epilogue fusion
|
||||||
from torch._inductor import config
|
from torch._inductor import config
|
||||||
|
|
||||||
if not config.enable_custom_op_epilogue_fusion:
|
if not config.enable_custom_op_epilogue_fusion:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if node1 is marked as a custom op result eligible for epilogue fusion
|
# Check if node1 is marked as a custom op result eligible for epilogue fusion
|
||||||
if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
|
if (
|
||||||
hasattr(node1.node.data, '_custom_op_fusion_metadata')):
|
hasattr(node1, "node")
|
||||||
|
and hasattr(node1.node, "data")
|
||||||
|
and hasattr(node1.node.data, "_custom_op_fusion_metadata")
|
||||||
|
):
|
||||||
metadata = node1.node.data._custom_op_fusion_metadata
|
metadata = node1.node.data._custom_op_fusion_metadata
|
||||||
if metadata.get('epilogue_fusion_enabled', False):
|
if metadata.get("epilogue_fusion_enabled", False):
|
||||||
|
|
||||||
# Check if node2 is a suitable epilogue operation
|
# Check if node2 is a suitable epilogue operation
|
||||||
if (node2.is_pointwise() and
|
if (
|
||||||
not node2.is_reduction() and
|
node2.is_pointwise()
|
||||||
not node2.has_aliasing_or_mutation()):
|
and not node2.is_reduction()
|
||||||
|
and not node2.has_aliasing_or_mutation()
|
||||||
fusion_log.info("Custom op epilogue fusion enabled for %s -> %s (custom_op: %s)",
|
):
|
||||||
node1.get_first_name(), node2.get_first_name(),
|
fusion_log.info(
|
||||||
metadata.get('custom_op_name', 'unknown'))
|
"Custom op epilogue fusion enabled for %s -> %s (custom_op: %s)",
|
||||||
|
node1.get_first_name(),
|
||||||
|
node2.get_first_name(),
|
||||||
|
metadata.get("custom_op_name", "unknown"),
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Enhanced check: also look for custom ops directly in the node
|
# Enhanced check: also look for custom ops directly in the node
|
||||||
if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
|
if (
|
||||||
hasattr(node1.node.data, 'name') and
|
hasattr(node1, "node")
|
||||||
hasattr(node1.node.data, 'get_inputs')):
|
and hasattr(node1.node, "data")
|
||||||
|
and hasattr(node1.node.data, "name")
|
||||||
|
and hasattr(node1.node.data, "get_inputs")
|
||||||
|
):
|
||||||
# Check if this is a result from our custom op autotune system
|
# Check if this is a result from our custom op autotune system
|
||||||
if (hasattr(node1.node.data, 'get_name') and
|
if hasattr(node1.node.data, "get_name") and "_autotuned" in str(
|
||||||
'_autotuned' in str(node1.node.data.get_name())):
|
node1.node.data.get_name()
|
||||||
|
):
|
||||||
# Apply similar checks as template epilogue fusion
|
# Apply similar checks as template epilogue fusion
|
||||||
if (node2.is_pointwise() and
|
if (
|
||||||
not node2.is_reduction() and
|
node2.is_pointwise()
|
||||||
not node2.has_aliasing_or_mutation()):
|
and not node2.is_reduction()
|
||||||
|
and not node2.has_aliasing_or_mutation()
|
||||||
fusion_log.debug("Custom op epilogue candidate: %s -> %s",
|
):
|
||||||
node1.get_first_name(), node2.get_first_name())
|
fusion_log.debug(
|
||||||
|
"Custom op epilogue candidate: %s -> %s",
|
||||||
|
node1.get_first_name(),
|
||||||
|
node2.get_first_name(),
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@ -4478,42 +4512,54 @@ class Scheduler:
|
|||||||
"""
|
"""
|
||||||
# Check if global config enables custom op prologue fusion
|
# Check if global config enables custom op prologue fusion
|
||||||
from torch._inductor import config
|
from torch._inductor import config
|
||||||
|
|
||||||
if not config.enable_custom_op_prologue_fusion:
|
if not config.enable_custom_op_prologue_fusion:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if node2 is marked as a custom op that supports prologue fusion
|
# Check if node2 is marked as a custom op that supports prologue fusion
|
||||||
if (hasattr(node2, 'node') and hasattr(node2.node, 'data') and
|
if (
|
||||||
hasattr(node2.node.data, '_custom_op_fusion_metadata')):
|
hasattr(node2, "node")
|
||||||
|
and hasattr(node2.node, "data")
|
||||||
|
and hasattr(node2.node.data, "_custom_op_fusion_metadata")
|
||||||
|
):
|
||||||
metadata = node2.node.data._custom_op_fusion_metadata
|
metadata = node2.node.data._custom_op_fusion_metadata
|
||||||
if metadata.get('prologue_fusion_enabled', False):
|
if metadata.get("prologue_fusion_enabled", False):
|
||||||
|
|
||||||
# Check if node1 is a suitable prologue operation
|
# Check if node1 is a suitable prologue operation
|
||||||
if (node1.is_pointwise() and
|
if (
|
||||||
not node1.is_reduction() and
|
node1.is_pointwise()
|
||||||
not node1.has_aliasing_or_mutation()):
|
and not node1.is_reduction()
|
||||||
|
and not node1.has_aliasing_or_mutation()
|
||||||
fusion_log.info("Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
|
):
|
||||||
node1.get_first_name(), node2.get_first_name(),
|
fusion_log.info(
|
||||||
metadata.get('custom_op_name', 'unknown'))
|
"Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
|
||||||
|
node1.get_first_name(),
|
||||||
|
node2.get_first_name(),
|
||||||
|
metadata.get("custom_op_name", "unknown"),
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Enhanced check: also look for custom ops directly in the node
|
# Enhanced check: also look for custom ops directly in the node
|
||||||
if (hasattr(node2, 'node') and hasattr(node2.node, 'data') and
|
if (
|
||||||
hasattr(node2.node.data, 'name') and
|
hasattr(node2, "node")
|
||||||
hasattr(node2.node.data, 'get_inputs')):
|
and hasattr(node2.node, "data")
|
||||||
|
and hasattr(node2.node.data, "name")
|
||||||
|
and hasattr(node2.node.data, "get_inputs")
|
||||||
|
):
|
||||||
# Check if this is a result from our custom op autotune system
|
# Check if this is a result from our custom op autotune system
|
||||||
if (hasattr(node2.node.data, 'get_name') and
|
if hasattr(node2.node.data, "get_name") and "_autotuned" in str(
|
||||||
'_autotuned' in str(node2.node.data.get_name())):
|
node2.node.data.get_name()
|
||||||
|
):
|
||||||
# Apply similar checks as template prologue fusion
|
# Apply similar checks as template prologue fusion
|
||||||
if (node1.is_pointwise() and
|
if (
|
||||||
not node1.is_reduction() and
|
node1.is_pointwise()
|
||||||
not node1.has_aliasing_or_mutation()):
|
and not node1.is_reduction()
|
||||||
|
and not node1.has_aliasing_or_mutation()
|
||||||
fusion_log.debug("Custom op prologue candidate: %s -> %s",
|
):
|
||||||
node1.get_first_name(), node2.get_first_name())
|
fusion_log.debug(
|
||||||
|
"Custom op prologue candidate: %s -> %s",
|
||||||
|
node1.get_first_name(),
|
||||||
|
node2.get_first_name(),
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
Reference in New Issue
Block a user