mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[Inductor] Add triton.autotune support for user defined triton kernels with complex grids (#112290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112290 Approved by: https://github.com/jansel
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							5a1a9dc354
						
					
				
				
					commit
					1250032c2e
				
			@ -689,31 +689,59 @@ class TritonKernelVariable(VariableTracker):
 | 
			
		||||
    def call_function(
 | 
			
		||||
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
 | 
			
		||||
    ) -> "VariableTracker":
 | 
			
		||||
        from triton.runtime.autotuner import Autotuner
 | 
			
		||||
 | 
			
		||||
        from .constant import ConstantVariable
 | 
			
		||||
        from .dicts import ConstDictVariable
 | 
			
		||||
        from .lists import BaseListVariable
 | 
			
		||||
 | 
			
		||||
        grid = self.grid
 | 
			
		||||
 | 
			
		||||
        if grid is None:
 | 
			
		||||
        if self.grid is None:
 | 
			
		||||
            raise Unsupported("Triton kernels should always be called with a grid")
 | 
			
		||||
 | 
			
		||||
        # Both for grid's meta as well as for the kernel, we need combined
 | 
			
		||||
        # args and kwargs normalized
 | 
			
		||||
        normalized_args = {**dict(zip(self.kernel.arg_names, args)), **kwargs}
 | 
			
		||||
        meta = ConstDictVariable(normalized_args, dict)
 | 
			
		||||
 | 
			
		||||
        # If the grid is a function, then lets execute it and convert it to
 | 
			
		||||
        # a list
 | 
			
		||||
        if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
 | 
			
		||||
            # Populate the special "meta" argument to call the grid function
 | 
			
		||||
            grid = grid.call_function(tx, [meta], {})
 | 
			
		||||
        configs = (
 | 
			
		||||
            [config.kwargs for config in self.kernel.configs]
 | 
			
		||||
            if isinstance(self.kernel, Autotuner)
 | 
			
		||||
            else [{}]
 | 
			
		||||
        )
 | 
			
		||||
        grids = []
 | 
			
		||||
        for config_args in configs:
 | 
			
		||||
            # If the grid is a function, then lets execute it and convert it to
 | 
			
		||||
            # a list
 | 
			
		||||
            grid = self.grid
 | 
			
		||||
            if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
 | 
			
		||||
                # Populate the special "meta" argument to call the grid function
 | 
			
		||||
                config_args = {
 | 
			
		||||
                    k: ConstantVariable.create(v) for k, v in config_args.items()
 | 
			
		||||
                }
 | 
			
		||||
                meta = ConstDictVariable({**normalized_args, **config_args}, dict)
 | 
			
		||||
                grid = grid.call_function(tx, [meta], {})
 | 
			
		||||
 | 
			
		||||
        # Now, the grid must be a list either originally or through above
 | 
			
		||||
        # modification
 | 
			
		||||
        if isinstance(grid, BaseListVariable):
 | 
			
		||||
            grid = grid.as_proxy()
 | 
			
		||||
        else:
 | 
			
		||||
            unimplemented(f"grid for the triton kernel is {type(grid)}")
 | 
			
		||||
            # Now, the grid must be a list either originally or through above
 | 
			
		||||
            # modification
 | 
			
		||||
            if isinstance(grid, BaseListVariable):
 | 
			
		||||
                grids.append(grid.as_proxy())
 | 
			
		||||
            else:
 | 
			
		||||
                unimplemented(f"grid for the triton kernel is {type(grid)}")
 | 
			
		||||
 | 
			
		||||
        for i in range(len(grids)):
 | 
			
		||||
            if not isinstance(grids[i], tuple):
 | 
			
		||||
                raise Unsupported("Only tuple grids are supported")
 | 
			
		||||
            # inductor expects all grids to be 3-tuple so lets make it
 | 
			
		||||
            if len(grids[i]) == 1:
 | 
			
		||||
                grids[i] = (grids[i][0], 1, 1)
 | 
			
		||||
            elif len(grids[i]) == 2:
 | 
			
		||||
                grids[i] = (grids[i][0], grids[i][1], 1)
 | 
			
		||||
            elif len(grids[i]) > 3:
 | 
			
		||||
                raise Unsupported("Grid can have at most rank 3")
 | 
			
		||||
 | 
			
		||||
        assert len(grids) != 0
 | 
			
		||||
        if len(set(grids)) == 1:
 | 
			
		||||
            # If there's only one unique grid, lets simplify
 | 
			
		||||
            grids = [grids[0]]
 | 
			
		||||
 | 
			
		||||
        from torch._higher_order_ops.triton_kernel_wrap import (
 | 
			
		||||
            triton_kernel_wrapper_mutation,
 | 
			
		||||
@ -722,13 +750,14 @@ class TritonKernelVariable(VariableTracker):
 | 
			
		||||
        # Combine args and kwargs and pass as a dict so that if user defined triton
 | 
			
		||||
        # kernel uses variables as 'grid' or 'kernel', it does not conflict with
 | 
			
		||||
        # parameters of the wrapper function
 | 
			
		||||
        meta = ConstDictVariable(normalized_args, dict)
 | 
			
		||||
        tx.output.create_proxy(
 | 
			
		||||
            "call_function",
 | 
			
		||||
            triton_kernel_wrapper_mutation,
 | 
			
		||||
            (),
 | 
			
		||||
            {
 | 
			
		||||
                "kernel_idx": self.kernel_idx,
 | 
			
		||||
                "grid": grid,
 | 
			
		||||
                "grid": grids,
 | 
			
		||||
                "kwargs": meta.as_proxy(),
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user