mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[Inductor] Add CPU_MAX_FIRST_DIMENSION_DECOMPOSITION and CPU_MAX_OTHER_DIMENSION_DECOMPOSITION for decompose_mm_pass (#158183)
Differential Revision: D78209993 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158183 Approved by: https://github.com/houseroad
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							1b389025ba
						
					
				
				
					commit
					7f9fc7e67c
				
			| @ -15,11 +15,17 @@ aten = torch.ops.aten | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| # TODO: need a better strategy for decomposing mm | ||||
| # The following two constants are for CUDA device only | ||||
| MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 | ||||
| MAX_OTHER_DIMENSION_DECOMPOSITION = 32 | ||||
| # The following two constants are for CPU device only | ||||
| CPU_MAX_FIRST_DIMENSION_DECOMPOSITION = 1 | ||||
| CPU_MAX_OTHER_DIMENSION_DECOMPOSITION = 2048 | ||||
|  | ||||
| min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION | ||||
| max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION | ||||
| cpu_max_first_dimension_decomposition = CPU_MAX_FIRST_DIMENSION_DECOMPOSITION | ||||
| cpu_max_other_dimension_decomposition = CPU_MAX_OTHER_DIMENSION_DECOMPOSITION | ||||
| if "decompose_mm_pass" in config.post_grad_fusion_options: | ||||
|     min_first_dimension_decomposition = config.post_grad_fusion_options[ | ||||
|         "decompose_mm_pass" | ||||
| @ -27,6 +33,16 @@ if "decompose_mm_pass" in config.post_grad_fusion_options: | ||||
|     max_other_dimension_decomposition = config.post_grad_fusion_options[ | ||||
|         "decompose_mm_pass" | ||||
|     ].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) | ||||
|     cpu_max_first_dimension_decomposition = config.post_grad_fusion_options[ | ||||
|         "decompose_mm_pass" | ||||
|     ].get( | ||||
|         "cpu_max_first_dimension_decomposition", CPU_MAX_FIRST_DIMENSION_DECOMPOSITION | ||||
|     ) | ||||
|     cpu_max_other_dimension_decomposition = config.post_grad_fusion_options[ | ||||
|         "decompose_mm_pass" | ||||
|     ].get( | ||||
|         "cpu_max_other_dimension_decomposition", CPU_MAX_OTHER_DIMENSION_DECOMPOSITION | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def check_device(a: Tensor, b: Tensor, device="cuda") -> bool: | ||||
| @ -57,7 +73,10 @@ def should_decompose_bmm(mat1, mat2) -> bool: | ||||
|             return False | ||||
|         return True | ||||
|     elif check_device(mat1, mat2, device="cpu"): | ||||
|         if mat1.shape[0] == 1 and mat2.shape[0] == 1: | ||||
|         if ( | ||||
|             mat1.shape[0] <= cpu_max_first_dimension_decomposition | ||||
|             and mat2.shape[0] <= cpu_max_first_dimension_decomposition | ||||
|         ): | ||||
|             return True | ||||
|     return False | ||||
|  | ||||
| @ -77,9 +96,15 @@ def should_decompose_mm(mat1, mat2) -> bool: | ||||
|         and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition) | ||||
|     ) or ( | ||||
|         check_device(mat1, mat2, device="cpu") | ||||
|         and statically_known_true(mat1.shape[0] == 1) | ||||
|         and statically_known_true(mat2.shape[0] <= 128) | ||||
|         and statically_known_true(mat2.shape[1] <= 512) | ||||
|         and statically_known_true( | ||||
|             mat1.shape[0] <= cpu_max_first_dimension_decomposition | ||||
|         ) | ||||
|         and statically_known_true( | ||||
|             mat2.shape[0] <= cpu_max_other_dimension_decomposition | ||||
|         ) | ||||
|         and statically_known_true( | ||||
|             mat2.shape[1] <= cpu_max_other_dimension_decomposition | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user