diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 062bc42be40c..ab0070027710 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -256,8 +256,12 @@ prologue_fusion = prologue_fusion_enabled()
 epilogue_fusion_first = False
 
 # enable custom op fusion support
-enable_custom_op_epilogue_fusion = os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
-enable_custom_op_prologue_fusion = os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
+enable_custom_op_epilogue_fusion = (
+    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_EPILOGUE_FUSION", "1") == "1"
+)
+enable_custom_op_prologue_fusion = (
+    os.environ.get("TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION", "1") == "1"
+)
 
 # enable pattern match+replace optimizations
 pattern_matcher = True
diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py
index 6c3cfd91bb65..a5476af04636 100644
--- a/torch/_inductor/kernel/custom_op.py
+++ b/torch/_inductor/kernel/custom_op.py
@@ -65,7 +65,6 @@ def _create_user_input_gen_fns(
 
     Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
     """
-    from torch._inductor import config
 
     name_to_index = {name: i for i, name in enumerate(arg_names)}
     index_based_fns = {}
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 18921523a327..da316f53ff59 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -187,6 +187,16 @@ class SchedulerBuffer:
                 input_buffer,
                 self.node,
             )
+        else:
+            V.graph.wrapper_code.codegen_allocation(self.node)
+
+    def can_free(self) -> bool:
+        # There's no real allocated buffer, no need to free it
+        assert self.node is not None
+        if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template(
+            self.node
+        ):
+            return False
         for use in self.users:
             if isinstance(use.node, OutputNode):
                 return False
@@ -3218,19 +3228,25 @@ class Scheduler:
             for n in (node1, node2)
         )
         # Check for custom op fusion cases
-        is_custom_op_fusion = (
-            self._can_apply_custom_op_epilogue_fusion(node1, node2) or
-            self._can_apply_custom_op_prologue_fusion(node1, node2)
-        )
+        is_custom_op_fusion = self._can_apply_custom_op_epilogue_fusion(
+            node1, node2
+        ) or self._can_apply_custom_op_prologue_fusion(node1, node2)
 
-        if not config.benchmark_fusion and not is_multi_template and not is_custom_op_fusion:
+        if (
+            not config.benchmark_fusion
+            and not is_multi_template
+            and not is_custom_op_fusion
+        ):
             return True
 
         # For custom op fusion, we want to benchmark by default unless explicitly disabled
         if is_custom_op_fusion and not config.benchmark_fusion:
             # Still allow benchmark for custom ops even if global benchmark_fusion is off
-            fusion_log.debug("Benchmarking custom op fusion: %s <-> %s",
-                             node1.get_first_name(), node2.get_first_name())
+            fusion_log.debug(
+                "Benchmarking custom op fusion: %s <-> %s",
+                node1.get_first_name(),
+                node2.get_first_name(),
+            )
 
         if (
             node1.is_template()
@@ -4287,8 +4303,11 @@ class Scheduler:
 
         # Check for custom op prologue fusion
         if self._can_apply_custom_op_prologue_fusion(node1, node2):
-            fusion_log.debug("Custom op prologue fusion applicable for %s -> %s",
-                             node1.get_first_name(), node2.get_first_name())
+            fusion_log.debug(
+                "Custom op prologue fusion applicable for %s -> %s",
+                node1.get_first_name(),
+                node2.get_first_name(),
+            )
             return True
 
         if node1.is_template() and (
@@ -4363,8 +4382,11 @@ class Scheduler:
         """
         # Check for custom op epilogue fusion opportunities first
         if self._can_apply_custom_op_epilogue_fusion(node1, node2):
-            fusion_log.debug("Custom op epilogue fusion applicable for %s -> %s",
-                             node1.get_first_name(), node2.get_first_name())
+            fusion_log.debug(
+                "Custom op epilogue fusion applicable for %s -> %s",
+                node1.get_first_name(),
+                node2.get_first_name(),
+            )
             return True
 
         node1_buf_names = node1.get_buffer_names()
@@ -4424,42 +4446,54 @@ class Scheduler:
         """
         # Check if global config enables custom op epilogue fusion
        from torch._inductor import config
+
         if not config.enable_custom_op_epilogue_fusion:
             return False
 
         # Check if node1 is marked as a custom op result eligible for epilogue fusion
-        if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
-                hasattr(node1.node.data, '_custom_op_fusion_metadata')):
-
+        if (
+            hasattr(node1, "node")
+            and hasattr(node1.node, "data")
+            and hasattr(node1.node.data, "_custom_op_fusion_metadata")
+        ):
             metadata = node1.node.data._custom_op_fusion_metadata
-            if metadata.get('epilogue_fusion_enabled', False):
-
+            if metadata.get("epilogue_fusion_enabled", False):
                 # Check if node2 is a suitable epilogue operation
-                if (node2.is_pointwise() and
-                        not node2.is_reduction() and
-                        not node2.has_aliasing_or_mutation()):
-
-                    fusion_log.info("Custom op epilogue fusion enabled for %s -> %s (custom_op: %s)",
-                                    node1.get_first_name(), node2.get_first_name(),
-                                    metadata.get('custom_op_name', 'unknown'))
+                if (
+                    node2.is_pointwise()
+                    and not node2.is_reduction()
+                    and not node2.has_aliasing_or_mutation()
+                ):
+                    fusion_log.info(
+                        "Custom op epilogue fusion enabled for %s -> %s (custom_op: %s)",
+                        node1.get_first_name(),
+                        node2.get_first_name(),
+                        metadata.get("custom_op_name", "unknown"),
+                    )
                     return True
 
         # Enhanced check: also look for custom ops directly in the node
-        if (hasattr(node1, 'node') and hasattr(node1.node, 'data') and
-                hasattr(node1.node.data, 'name') and
-                hasattr(node1.node.data, 'get_inputs')):
-
+        if (
+            hasattr(node1, "node")
+            and hasattr(node1.node, "data")
+            and hasattr(node1.node.data, "name")
+            and hasattr(node1.node.data, "get_inputs")
+        ):
             # Check if this is a result from our custom op autotune system
-            if (hasattr(node1.node.data, 'get_name') and
-                    '_autotuned' in str(node1.node.data.get_name())):
-
+            if hasattr(node1.node.data, "get_name") and "_autotuned" in str(
+                node1.node.data.get_name()
+            ):
                 # Apply similar checks as template epilogue fusion
-                if (node2.is_pointwise() and
-                        not node2.is_reduction() and
-                        not node2.has_aliasing_or_mutation()):
-
-                    fusion_log.debug("Custom op epilogue candidate: %s -> %s",
-                                     node1.get_first_name(), node2.get_first_name())
+                if (
+                    node2.is_pointwise()
+                    and not node2.is_reduction()
+                    and not node2.has_aliasing_or_mutation()
+                ):
+                    fusion_log.debug(
+                        "Custom op epilogue candidate: %s -> %s",
+                        node1.get_first_name(),
+                        node2.get_first_name(),
+                    )
                     return True
 
         return False
@@ -4478,42 +4512,54 @@ class Scheduler:
         """
         # Check if global config enables custom op prologue fusion
         from torch._inductor import config
+
         if not config.enable_custom_op_prologue_fusion:
             return False
 
         # Check if node2 is marked as a custom op that supports prologue fusion
-        if (hasattr(node2, 'node') and hasattr(node2.node, 'data') and
-                hasattr(node2.node.data, '_custom_op_fusion_metadata')):
-
+        if (
+            hasattr(node2, "node")
+            and hasattr(node2.node, "data")
+            and hasattr(node2.node.data, "_custom_op_fusion_metadata")
+        ):
             metadata = node2.node.data._custom_op_fusion_metadata
-            if metadata.get('prologue_fusion_enabled', False):
-
+            if metadata.get("prologue_fusion_enabled", False):
                 # Check if node1 is a suitable prologue operation
-                if (node1.is_pointwise() and
-                        not node1.is_reduction() and
-                        not node1.has_aliasing_or_mutation()):
-
-                    fusion_log.info("Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
-                                    node1.get_first_name(), node2.get_first_name(),
-                                    metadata.get('custom_op_name', 'unknown'))
+                if (
+                    node1.is_pointwise()
+                    and not node1.is_reduction()
+                    and not node1.has_aliasing_or_mutation()
+                ):
+                    fusion_log.info(
+                        "Custom op prologue fusion enabled for %s -> %s (custom_op: %s)",
+                        node1.get_first_name(),
+                        node2.get_first_name(),
+                        metadata.get("custom_op_name", "unknown"),
+                    )
                     return True
 
         # Enhanced check: also look for custom ops directly in the node
-        if (hasattr(node2, 'node') and hasattr(node2.node, 'data') and
-                hasattr(node2.node.data, 'name') and
-                hasattr(node2.node.data, 'get_inputs')):
-
+        if (
+            hasattr(node2, "node")
+            and hasattr(node2.node, "data")
+            and hasattr(node2.node.data, "name")
+            and hasattr(node2.node.data, "get_inputs")
+        ):
             # Check if this is a result from our custom op autotune system
-            if (hasattr(node2.node.data, 'get_name') and
-                    '_autotuned' in str(node2.node.data.get_name())):
-
+            if hasattr(node2.node.data, "get_name") and "_autotuned" in str(
+                node2.node.data.get_name()
+            ):
                 # Apply similar checks as template prologue fusion
-                if (node1.is_pointwise() and
-                        not node1.is_reduction() and
-                        not node1.has_aliasing_or_mutation()):
-
-                    fusion_log.debug("Custom op prologue candidate: %s -> %s",
-                                     node1.get_first_name(), node2.get_first_name())
+                if (
+                    node1.is_pointwise()
+                    and not node1.is_reduction()
+                    and not node1.has_aliasing_or_mutation()
+                ):
+                    fusion_log.debug(
+                        "Custom op prologue candidate: %s -> %s",
+                        node1.get_first_name(),
+                        node2.get_first_name(),
+                    )
                     return True
 
         return False
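
Usage sketch for the new flags (not part of the patch): the two config values added in torch/_inductor/config.py are computed from environment variables once, at module import time, and default to enabled, while the scheduler checks re-read the config attributes on every fusion query. The snippet below is a minimal sketch assuming a build with this patch applied; the op name and the way the _custom_op_fusion_metadata dict gets attached to an IR node are illustrative assumptions, and only the flag names, env var names, and dict keys come from this diff.

import os

# The env vars are evaluated once, when torch._inductor.config is imported,
# so they must be set before importing torch. Both default to "1" (enabled).
os.environ["TORCHINDUCTOR_CUSTOM_OP_PROLOGUE_FUSION"] = "0"

import torch  # noqa: F401
from torch._inductor import config

assert config.enable_custom_op_epilogue_fusion  # still on (the default)
assert not config.enable_custom_op_prologue_fusion  # disabled via the env var

# _can_apply_custom_op_*_fusion reads the config attribute at fusion-check
# time, so the flags can also be flipped after import:
config.enable_custom_op_epilogue_fusion = False

# Shape of the dict the new scheduler checks look for on
# node.node.data._custom_op_fusion_metadata ("my_lib::my_op" is hypothetical):
metadata = {
    "epilogue_fusion_enabled": True,  # gates _can_apply_custom_op_epilogue_fusion
    "prologue_fusion_enabled": False,  # gates _can_apply_custom_op_prologue_fusion
    "custom_op_name": "my_lib::my_op",  # used only in the fusion_log messages
}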