mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[inductor][refactor] Unify the use of generate_kernel_call (#128467)
Summary: Refactor TritonTemplateKernel.call_kernel and ForeachKernel.call_kernel to use wrapper.generate_kernel_call to generate kernel calls instead of explicitly composing the kernel call string. This consolidates the entry point of generate_kernel_call and similifies later changes in this PR stack. Differential Revision: [D58733631](https://our.internmc.facebook.com/intern/diff/D58733631) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128467 Approved by: https://github.com/shunting314
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							3a185778ed
						
					
				
				
					commit
					ba92f5277f
				
			| @ -198,11 +198,12 @@ class MultiKernel: | ||||
|         for the multi-kernel. | ||||
|         """ | ||||
|         assert kernel_name == self.kernel_name | ||||
|         call_args_list, arg_types = zip( | ||||
|             *[kernel.get_call_args() for kernel in self.kernels] | ||||
|         ) | ||||
|         call_args_list = list(call_args_list) | ||||
|         arg_types_list = list(arg_types) | ||||
|         call_args_list = [] | ||||
|         arg_types_list = [] | ||||
|         for kernel in self.kernels: | ||||
|             _, call_args, _, arg_types = kernel.args.python_argdefs() | ||||
|             call_args_list.append(call_args) | ||||
|             arg_types_list.append(arg_types) | ||||
|  | ||||
|         all_call_args, arg_types = get_all_call_args(call_args_list, arg_types_list) | ||||
|         grid: List[Any] = [] | ||||
| @ -223,12 +224,10 @@ class MultiKernel: | ||||
|         ) | ||||
|  | ||||
|         grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid) | ||||
|         current_device = V.graph.scheduler.get_current_device_or_throw() | ||||
|         V.graph.wrapper_code.generate_kernel_call( | ||||
|             kernel_name, | ||||
|             final_call_args, | ||||
|             grid, | ||||
|             current_device.index, | ||||
|             arg_types=arg_types, | ||||
|         ) | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user