From f194b740f7e558b133757bede342b3a0739e671a Mon Sep 17 00:00:00 2001
From: Tianren Gao
Date: Thu, 16 Oct 2025 00:48:02 -0700
Subject: [PATCH] resolve comments: fix input_gen_fns API; infer tensor shapes
 with size_hints instead of ir_node_to_tensor; separate non-tensor args into
 functools.partial kwargs; add single-output layout assertion; lint

---
 test/inductor/test_custom_op_autotune.py | 108 ++++++++++-------
 torch/_inductor/codegen/subgraph.py      |  50 +++++++-
 torch/_inductor/kernel/custom_op.py      | 146 ++++++++++++-----------
 3 files changed, 188 insertions(+), 116 deletions(-)

diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py
index 11394de04d91..01f3a9a1387e 100644
--- a/test/inductor/test_custom_op_autotune.py
+++ b/test/inductor/test_custom_op_autotune.py
@@ -1,17 +1,14 @@
 # Owner(s): ["module: inductor"]
 """
-Test suite for custom operation autotuning integration with PyTorch Inductor.
+Tests for custom operation autotuning with PyTorch Inductor.
 
-This module tests the custom op autotuning system, which allows users to provide
-multiple decomposition implementations of custom operations and automatically
-select the best performing one through Inductor's autotuning system.
+Users can register custom ops with multiple decomposition implementations and let
+Inductor automatically select the best performing variant. Key features tested:
 
-The tests cover:
-1. Custom op registration and autotuning for different operations (RMSNorm, MLP, Attention)
-2. Numerical equivalence between different decomposition implementations
-3. End-to-end compilation and performance validation
-4. Fallback behavior when decompositions fail
-5. Max-autotune integration with extended configuration sets
+- Name-based input generators (use argument names instead of indices)
+- Dynamic shape handling across multiple compilations
+- Parametric tuning with tuning_knob for combinatorial parameter exploration
+- Numerical correctness and performance validation
 """
 
 import torch
@@ -154,8 +151,8 @@ class TestCustomOpAutoTune(TestCase):
         return input_tensor, gate_weight, up_weight, down_weight
 
     @skipIfXpu
-    def test_rmsnorm_custom_op_autotune(self):
-        """Test RMSNorm autotuning with multiple decomposition variants showcasing different performance characteristics."""
+    def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
+        """Test RMSNorm autotuning of decomposition variants against the default fallback, with dynamic shapes."""
         test_op_name = f"test_lib::rmsnorm_{id(self)}"
 
         def rmsnorm_decomposition1(
@@ -216,31 +213,42 @@ class TestCustomOpAutoTune(TestCase):
             rmsnorm_decomposition3,
         ]
 
-        # Example of user-friendly input generation functions
         register_custom_op_autotuning(
             op_object.default,
             decompositions=decompositions,
             name="test_rmsnorm_autotuned",
             input_gen_fns={
-                0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
-                * 0.02,  # Small values for input
-                1: lambda fake_tensor: torch.ones_like(
-                    fake_tensor, device=self.device
-                ),  # Ones for weight
+                "x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
+                "weight": lambda weight: torch.ones_like(weight, device=self.device),
             },
         )
 
-        # Test inputs
-        input_tensor, weight = self._create_rmsnorm_inputs()
+        # Test multiple shapes to verify dynamic shape handling
+        test_shapes = [(2, 16, 128), (8, 32, 256)]
 
-        # Test numerical equivalence for all decompositions
-        self._assert_implementations_equivalent(
-            decompositions, (input_tensor, weight), "RMSNorm"
-        )
-
-        # Test autotuning
-        expected = rmsnorm_decomposition1(input_tensor, weight)
-        self._run_autotune_test(op_object, (input_tensor, weight), expected, "RMSNorm")
+        for i, (batch_size, seq_len,
hidden_dim) in enumerate(test_shapes): + input_tensor = torch.randn( + batch_size, + seq_len, + hidden_dim, + device=self.device, + dtype=self.dtype, + requires_grad=False, + ) + weight = torch.randn( + hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False + ) - # Test autotuning - expected = rmsnorm_decomposition1(input_tensor, weight) - self._run_autotune_test(op_object, (input_tensor, weight), expected, "RMSNorm") + # Test numerical equivalence for all decompositions + self._assert_implementations_equivalent( + decompositions, (input_tensor, weight), f"RMSNorm_{i}" + ) + + # Test autotuning + expected = rmsnorm_decomposition1(input_tensor, weight) + self._run_autotune_test( + op_object, (input_tensor, weight), expected, f"RMSNorm_{i}" + ) @skipIfXpu def test_mlp_custom_op_autotune(self): @@ -327,14 +335,22 @@ class TestCustomOpAutoTune(TestCase): decompositions=decompositions, name="test_mlp_autotuned", input_gen_fns={ - 0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device) - * 0.1, # Input tensor - 1: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device) - * 0.05, # Gate weight - 2: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device) - * 0.05, # Up weight - 3: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device) - * 0.05, # Down weight + "input_tensor": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, + "gate_weight": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.05, + "up_weight": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.05, + "down_weight": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.05, }, ) @@ -408,10 +424,14 @@ class TestCustomOpAutoTune(TestCase): max_autotune_configs={"k_splits": [8, 16, 128, 256, 512]}, name="test_decompose_k_autotuned", input_gen_fns={ - 0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device) - * 0.1, - 1: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device) - * 0.1, + "a": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, # Matrix A + "b": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, # Matrix B }, ) @@ -498,8 +518,8 @@ class TestCustomOpAutoTune(TestCase): tuning_knob={"method": [0, 1, 2, 3, 4]}, name="parametric_norm_autotuned", input_gen_fns={ - 0: lambda t: torch.randn_like(t, device=self.device) * 0.1, - 1: lambda t: torch.ones_like(t, device=self.device), + "x": lambda t: torch.randn_like(t, device=self.device) * 0.1, + "weight": lambda t: torch.ones_like(t, device=self.device), }, ) @@ -579,8 +599,10 @@ class TestCustomOpAutoTune(TestCase): tuning_knob={"scale_mode": [1, 2, 3], "chunk_size": [16, 32]}, name="multi_param_autotuned", input_gen_fns={ - 0: lambda t: torch.randn_like(t, device=self.device) * 0.1, - 1: lambda t: torch.ones(t.shape[-1], device=self.device, dtype=t.dtype), + "x": lambda t: torch.randn_like(t, device=self.device) * 0.1, + "factor": lambda t: torch.ones( + t.shape[-1], device=self.device, dtype=t.dtype + ), }, ) diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index bdfada9e9ad8..85b97888faef 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -8,6 +8,7 @@ from torch._inductor import ir from torch._inductor.codegen.common import KernelTemplate from torch._inductor.ir import ( Buffer, + FixedLayout, 
     get_free_symbols,
     get_symbolic_inputs,
     gm_original_output_strides,
@@ -249,7 +250,17 @@ class SubgraphTemplate(KernelTemplate):
         ]
 
         self._validate_stride_consistency(name, decompositions, layouts)
-        layout = layouts[0]  # All layouts have equivalent stride now
+
+        # Single-output assumption: every decomposition must produce one output tensor with an equivalent layout
+        assert len(layouts) > 0, f"No layouts inferred for custom op '{name}'"
+        assert all(
+            layout.device == layouts[0].device
+            and layout.dtype == layouts[0].dtype
+            and layout.size == layouts[0].size
+            for layout in layouts
+        ), f"All decompositions for '{name}' must produce equivalent output layouts"
+
+        layout = layouts[0]  # All layouts have equivalent stride/shape/dtype now
 
         choices = []
         for decomp in decompositions:
@@ -302,20 +313,49 @@ class SubgraphTemplate(KernelTemplate):
         kwargs: dict[str, Any],
         default_impl: Optional[Callable[..., Any]] = None,
     ) -> Layout:
-        """Infer output layout for custom ops using the default implementation when available."""
+        """Infer output layout for custom ops using the default implementation when available.
+
+        Note: the subgraph path currently assumes custom ops return exactly one tensor.
+        TODO: Add support for multiple-output custom ops.
+        """
         import functools
 
-        from torch._inductor.ir import FixedLayout, ir_node_to_tensor
         from torch._inductor.virtualized import V
 
+        # Assert kwargs contain only non-tensor arguments for functools.partial
+        for key, value in kwargs.items():
+            assert not isinstance(value, (torch.Tensor, Buffer)), (
+                f"kwargs['{key}'] contains a tensor ({type(value)}). "
+                f"Tensor arguments should be in input_nodes, not kwargs. "
+                f"Only scalar/non-tensor parameters should be in kwargs."
+            )
+
         # Use default_impl if available, otherwise use first decomposition
         impl = default_impl if default_impl is not None else decompositions[0]
 
         with V.fake_mode:
-            example_inputs = [ir_node_to_tensor(inp) for inp in input_nodes]
-            fn = functools.partial(impl, **kwargs)
+            example_inputs = []
+            for inp in input_nodes:
+                raw_shape = inp.get_size()
+                concrete_shape = V.graph.sizevars.size_hints(
+                    raw_shape, fallback=config.unbacked_symint_fallback
+                )
+                fake_tensor = torch.empty(
+                    concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
+                )
+                example_inputs.append(fake_tensor)
+
+            fn = functools.partial(
+                impl, **kwargs
+            )  # kwargs must be non-tensor for partial
             output = fn(*example_inputs)
 
+        # Assert single output
+        assert isinstance(output, torch.Tensor), (
+            f"Expected single tensor output, got {type(output)}. "
+            f"Multi-output custom ops not yet supported in autotuning."
+        )
+
         return FixedLayout(
             device=output.device,
             dtype=output.dtype,
    Args:
        args: Positional arguments (mix of tensors and scalars)
@@ -36,16 +38,16 @@
     tensor_inputs = []
     non_tensor_kwargs = {}
 
-    for arg in args:
-        if isinstance(arg, (TensorBox, Buffer)) or (
-            hasattr(arg, "dtype") and hasattr(arg, "shape")
-        ):
+    # Process args and kwargs: separate tensor inputs from non-tensor args
+    for i, arg in enumerate(args):
+        if isinstance(arg, (TensorBox, Buffer)):
             tensor_inputs.append(arg)
+        else:
+            # Add non-tensor positional args to kwargs with generated names
+            non_tensor_kwargs[f"arg_{i}"] = arg
 
     for key, value in kwargs.items():
-        if isinstance(value, (TensorBox, Buffer)) or (
-            hasattr(value, "dtype") and hasattr(value, "shape")
-        ):
+        if isinstance(value, (TensorBox, Buffer)):
             tensor_inputs.append(value)
         else:
             non_tensor_kwargs[key] = value
@@ -55,46 +57,50 @@
 
 def _create_user_input_gen_fns(
     inputs: list[Any],
-    user_input_gen_fns: dict[int, Callable[[torch.Tensor], torch.Tensor]],
+    arg_names: list[str],
+    user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
 ) -> dict[int, Callable[[Any], torch.Tensor]]:
-    """Convert user input generators to internal format.
+    """Convert user input generators from name-based to index-based format.
+    Inductor's autotuning input_gen_fns expects argument indices as keys.
 
-    Args:
-        inputs: List of input IR nodes from compilation
-        user_input_gen_fns: User-provided input generation functions
-
-    Returns:
-        Dict mapping indices to internal input generation functions
+    Uses V.graph.sizevars.size_hints() to pick concrete sizes for dynamic shapes.
     """
-    internal_input_gen_fns = {}
+    from torch._inductor import config
 
-    with V.fake_mode:
-        fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
+    name_to_index = {name: i for i, name in enumerate(arg_names)}
+    index_based_fns = {}
+
+    for name, gen_fn in user_input_gen_fns.items():
+        if name in name_to_index:
+            index_based_fns[name_to_index[name]] = gen_fn
+        else:
+            print(f"Warning: Unknown argument name '{name}' in input_gen_fns")
 
     def create_internal_input_gen_fn(
-        user_function: Callable[[torch.Tensor], torch.Tensor],
-        template: torch.Tensor,
+        user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
     ) -> Callable[[Any], torch.Tensor]:
+        """Create internal input generator that converts an IR buffer to the user's fake tensor."""
+
         def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
-            fake_tensor_for_user = torch.empty(
-                template.shape,
-                dtype=template.dtype,
-                device="meta",
+            raw_shape = ir_buffer.get_size()
+            concrete_shape = V.graph.sizevars.size_hints(
+                raw_shape, fallback=config.unbacked_symint_fallback
             )
-            return user_function(fake_tensor_for_user)
+
+            fake_tensor = torch.empty(
+                concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
+            )
+            return user_function(fake_tensor)
 
         return internal_input_gen_fn
 
-    for i, user_gen_fn in user_input_gen_fns.items():
-        if i >= len(fake_inputs):
-            continue
-
-        fake_template = fake_inputs[i]
-        internal_input_gen_fns[i] = create_internal_input_gen_fn(
-            user_gen_fn, fake_template
+    return {
+        i: create_internal_input_gen_fn(
+            user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
         )
-
-    return internal_input_gen_fns
+        for i, user_gen_fn in index_based_fns.items()
+        if i < len(inputs)
+    }
 
 
 # Global cache for fallback choices to avoid duplicate creation
@@ -107,13 +113,7 @@ def _get_or_create_fallback_choice(
     fake_output: torch.Tensor,
    kwargs: dict[str, Any],
 ) -> ExternKernelChoice:
-    """Get or create fallback choice for default implementation."""
-
cache_key = (id(default_impl), name, tuple(sorted(kwargs.items()))) - - if cache_key not in _fallback_choice_cache: - - def fallback_wrapper(*args: Any) -> Any: - return default_impl(*args, **kwargs) + """Create fallback choice for default implementation.""" fallback_name = f"{name}_fallback_{default_impl._name}" _fallback_choice_cache[cache_key] = ExternKernelChoice( @@ -142,14 +142,10 @@ def _create_parameter_variants( """ # Validate parameter values for param_name, param_values in tuning_knob.items(): - if not isinstance(param_values, (list, tuple)): + if not param_values or not isinstance(param_values, (list, tuple)): raise TypeError( f"Parameter values for '{param_name}' must be a list or tuple, got {type(param_values)}" ) - if not param_values: - raise ValueError( - f"At least one parameter value must be provided for '{param_name}'" - ) # Generate all combinations of parameter values using Cartesian product import itertools @@ -167,8 +163,6 @@ def _create_parameter_variants( # Create partial function with all parameters variant = functools.partial(decomp_fn, **param_kwargs) - - # Generate descriptive name param_suffix = "_".join( f"{name}_{value}" for name, value in param_kwargs.items() ) @@ -185,11 +179,14 @@ def autotune_custom_op( kwargs: Optional[dict[str, Any]] = None, default_impl: Optional[Callable[..., Any]] = None, user_input_gen_fns: Optional[ - dict[int, Callable[[torch.Tensor], torch.Tensor]] + dict[str, Callable[[torch.Tensor], torch.Tensor]] ] = None, ) -> Union[TensorBox, Any]: """Autotune custom operations by comparing multiple decomposition implementations. + Currently supports SINGLE OUTPUT custom ops only. + TODO: Add support for multiple output custom ops (tuple/list returns). + This function generates multiple implementation choices for a custom operation and uses Inductor's autotuning system to select the best performing variant at runtime. 
@@ -231,24 +228,28 @@ def autotune_custom_op( # Add default implementation as fallback if default_impl and hasattr(default_impl, "_op"): - # Get output shape/dtype by calling default implementation with fake inputs - with V.fake_mode: - fake_inputs = [ir_node_to_tensor(inp) for inp in inputs] - fake_output = default_impl(*fake_inputs, **kwargs) + fallback_name = f"{name}_fallback_default" + from torch._inductor.select_algorithm import extern_kernels - fallback_choice = _get_or_create_fallback_choice( - name, default_impl, fake_output, kwargs - ) - fallback_choice.maybe_append_choice( - choices=choices, - input_nodes=list(inputs), - layout=FixedLayout( - device=fake_output.device, - dtype=fake_output.dtype, - size=fake_output.shape, - stride=fake_output.stride(), - ), - ) + # Skip if extern_kernel already registered to avoid duplicate registration error + if not hasattr(extern_kernels, fallback_name): + with V.fake_mode: + fake_inputs = [ir_node_to_tensor(inp) for inp in inputs] + fake_output = default_impl(*fake_inputs, **kwargs) + + fallback_choice = _create_fallback_choice( + name, default_impl, fake_output, kwargs + ) + fallback_choice.maybe_append_choice( + choices=choices, + input_nodes=list(inputs), + layout=FixedLayout( + device=fake_output.device, + dtype=fake_output.dtype, + size=fake_output.shape, + stride=fake_output.stride(), + ), + ) if not choices: raise RuntimeError(f"No valid choices generated for {name}") @@ -256,7 +257,16 @@ def autotune_custom_op( # Convert user input generation functions to internal format input_gen_fns = {} if user_input_gen_fns: - input_gen_fns = _create_user_input_gen_fns(inputs, user_input_gen_fns) + import inspect + + arg_names = ( + list(inspect.signature(decompositions[0]).parameters.keys()) + if decompositions + else [] + ) + input_gen_fns = _create_user_input_gen_fns( + inputs, arg_names, user_input_gen_fns + ) return autotune_select_algorithm( name=name, @@ -271,7 +281,7 @@ def register_custom_op_autotuning( custom_op: torch._ops.OpOverload, decompositions: list[Callable[..., Any]], name: Optional[str] = None, - input_gen_fns: Optional[dict[int, Callable[[torch.Tensor], torch.Tensor]]] = None, + input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None, tuning_knob: Optional[dict[str, list[Any]]] = None, max_autotune_configs: Optional[dict[str, list[Any]]] = None, ) -> None:
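---

Usage note (for reviewers): below is a minimal sketch of the name-based
input_gen_fns API and the tuning_knob expansion exercised by the tests above.
The op name "mylib::rmsnorm", the decomposition body, and the CUDA device are
illustrative assumptions; only register_custom_op_autotuning and its keyword
arguments come from this patch.

    import torch
    from torch._inductor.kernel.custom_op import register_custom_op_autotuning

    @torch.library.custom_op("mylib::rmsnorm", mutates_args=())
    def rmsnorm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Default implementation; the patch also benchmarks it as a fallback choice.
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + 1e-6) * weight

    @rmsnorm.register_fake
    def _(x, weight):
        return torch.empty_like(x)

    def rmsnorm_chunked(x, weight, chunk_size=32):
        # Candidate decomposition; chunk_size is bound per variant via the
        # tuning knob (functools.partial under the hood).
        out = torch.empty_like(x)
        for start in range(0, x.shape[-2], chunk_size):
            piece = x[..., start : start + chunk_size, :]
            out[..., start : start + chunk_size, :] = (
                piece * torch.rsqrt(piece.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
            )
        return out

    register_custom_op_autotuning(
        torch.ops.mylib.rmsnorm.default,
        decompositions=[rmsnorm_chunked],
        name="rmsnorm_autotuned",
        # Keys are argument *names* (the previous API used positional indices);
        # each fn receives a meta fake tensor whose shape was concretized with
        # V.graph.sizevars.size_hints().
        input_gen_fns={
            "x": lambda fake: torch.randn_like(fake, device="cuda") * 0.02,
            "weight": lambda fake: torch.ones_like(fake, device="cuda"),
        },
        # Expanded by Cartesian product into one autotuning variant per value.
        tuning_knob={"chunk_size": [16, 32, 64]},
    )

Under torch.compile, Inductor then benchmarks each chunk_size variant against
the default implementation and selects the fastest one at compile time.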