mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 20:34:54 +08:00 
			
		
		
		
	[inductor] Move LoopBody to its own file (#135257)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135257 Approved by: https://github.com/oulgen
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							18479c5f70
						
					
				
				
					commit
					eac5e12548
				
			| @ -16,6 +16,7 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT | ||||
| from torch.utils._sympy.value_ranges import ValueRanges | ||||
|  | ||||
| from .. import ir | ||||
| from ..loop_body import LoopBody | ||||
| from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs | ||||
| from ..virtualized import ops, OpsValue, V | ||||
| from .common import ( | ||||
| @ -883,20 +884,19 @@ def create_epilogue_with_attr(input_buffer, attr, **kwargs): | ||||
|  | ||||
|  | ||||
| def _get_loop_body(fn_list): | ||||
|     loop_bodies = None | ||||
|     if all(isinstance(fn, ir.LoopBody) for fn in fn_list): | ||||
|     if all(isinstance(fn, LoopBody) for fn in fn_list): | ||||
|         loop_bodies = fn_list | ||||
|     else: | ||||
|         if hasattr(fn_list[0], "original_fn"): | ||||
|             # For the case of local buffer, we wrap the fn with localize_function | ||||
|             assert all(hasattr(fn, "original_fn") for fn in fn_list) | ||||
|             assert all( | ||||
|                 isinstance(fn.original_fn.args[0]._body, ir.LoopBody) for fn in fn_list | ||||
|                 isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list | ||||
|             ) | ||||
|             loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] | ||||
|         else: | ||||
|             assert all(isinstance(fn, functools.partial) for fn in fn_list) | ||||
|             assert all(isinstance(fn.args[0]._body, ir.LoopBody) for fn in fn_list) | ||||
|             assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list) | ||||
|             loop_bodies = [fn.args[0]._body for fn in fn_list] | ||||
|     assert loop_bodies is not None | ||||
|     return loop_bodies | ||||
|  | ||||
		Reference in New Issue
	
	Block a user