mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Fix clamp type promotion in inductor decomposition (#154471)
Summary: as title, the clamp type promotion should take min/max arg into consideration as well. Test Plan: ``` buck run fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_clamp_decomposition_cpu python test/inductor/test_torchinductor.py -k test_clamp -v ``` Differential Revision: D75490124 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154471 Approved by: https://github.com/desertfire, https://github.com/chenyang78
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							d865b784e4
						
					
				
				
					commit
					3e05a48927
				
			| @ -5851,6 +5851,22 @@ class AOTInductorTestsTemplate: | ||||
|         # compare against eager | ||||
|         self.assertEqual(optimized(**model_kwargs), model(**model_kwargs)) | ||||
|  | ||||
|     def test_clamp_decomposition(self): | ||||
|         class Model1(torch.nn.Module): | ||||
|             def forward(self, x): | ||||
|                 return x.clamp(min=1.5) | ||||
|  | ||||
|         class Model2(torch.nn.Module): | ||||
|             def forward(self, x): | ||||
|                 return x.clamp(min=2) | ||||
|  | ||||
|         x = torch.randint(4, (4,)) | ||||
|  | ||||
|         # the output should have float32 type, not int | ||||
|         self.check_model(Model1(), (x,)) | ||||
|         # the output should have int type | ||||
|         self.check_model(Model2(), (x,)) | ||||
|  | ||||
|  | ||||
| class AOTInductorLoggingTest(LoggingTestCase): | ||||
|     @make_logging_test(dynamic=logging.DEBUG) | ||||
|  | ||||
| @ -2696,6 +2696,12 @@ class CommonTemplate: | ||||
|  | ||||
|         self.common(fn, (torch.randint(4, (4,)),)) | ||||
|  | ||||
|     def test_clamp_type_promotion_non_tensor(self): | ||||
|         def fn(a): | ||||
|             return a.clamp(min=1.5), a.clamp(min=2) | ||||
|  | ||||
|         self.common(fn, (torch.randint(4, (4,)),)) | ||||
|  | ||||
|     @skip_if_gpu_halide | ||||
|     @xfail_if_triton_cpu | ||||
|     def test_dist(self): | ||||
|  | ||||
| @ -58,11 +58,17 @@ def type_casts( | ||||
|     f: Callable, | ||||
|     type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND, | ||||
|     compute_dtype_only: bool = False, | ||||
|     include_non_tensor_args: bool = False, | ||||
| ): | ||||
|     @functools.wraps(f) | ||||
|     def inner(*args, **kwargs): | ||||
|         allowed_types = ( | ||||
|             (Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,) | ||||
|         )  # type: ignore[arg-type] | ||||
|         flat_args = [ | ||||
|             x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor) | ||||
|             x | ||||
|             for x in pytree.arg_tree_leaves(*args, **kwargs) | ||||
|             if isinstance(x, allowed_types) | ||||
|         ] | ||||
|         computation_dtype, result_dtype = utils.elementwise_dtypes( | ||||
|             *flat_args, type_promotion_kind=type_promotion | ||||
| @ -98,6 +104,11 @@ compute_only_pw_cast_for_opmath = partial( | ||||
| pw_cast_for_opmath = partial( | ||||
|     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT | ||||
| ) | ||||
| pw_cast_for_opmath_non_tensor_args = partial( | ||||
|     type_casts, | ||||
|     type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, | ||||
|     include_non_tensor_args=True, | ||||
| ) | ||||
| pw_cast_for_int_to_real = partial( | ||||
|     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT | ||||
| ) | ||||
|  | ||||
| @ -22,6 +22,7 @@ from torch._decomp.decompositions import ( | ||||
|     _index_add, | ||||
|     embedding_dense_backward as decomp_embedding_dense_backward, | ||||
|     pw_cast_for_opmath, | ||||
|     pw_cast_for_opmath_non_tensor_args, | ||||
| ) | ||||
| from torch._decomp.decompositions_for_rng import extra_random_decomps | ||||
| from torch._dynamo.utils import counters | ||||
| @ -181,7 +182,7 @@ def sym_constrain_range_for_size( | ||||
|  | ||||
|  | ||||
| @register_decomposition([aten.clamp]) | ||||
| @pw_cast_for_opmath | ||||
| @pw_cast_for_opmath_non_tensor_args | ||||
| def clamp( | ||||
|     x: torch.Tensor, | ||||
|     min: Optional[torch.types.Number] = None, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user