Fix clamp type promotion in inductor decomposition (#154471)

Summary: as title, the clamp type promotion should take min/max arg into consideration as well.

Test Plan:
```
buck run fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_clamp_decomposition_cpu
python test/inductor/test_torchinductor.py -k test_clamp -v
```

Differential Revision: D75490124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154471
Approved by: https://github.com/desertfire, https://github.com/chenyang78
This commit is contained in:
Shangdi Yu
2025-05-28 23:24:25 +00:00
committed by PyTorch MergeBot
parent d865b784e4
commit 3e05a48927
4 changed files with 36 additions and 2 deletions

View File

@ -58,11 +58,17 @@ def type_casts(
f: Callable,
type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
compute_dtype_only: bool = False,
include_non_tensor_args: bool = False,
):
@functools.wraps(f)
def inner(*args, **kwargs):
allowed_types = (
(Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,)
) # type: ignore[arg-type]
flat_args = [
x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
x
for x in pytree.arg_tree_leaves(*args, **kwargs)
if isinstance(x, allowed_types)
]
computation_dtype, result_dtype = utils.elementwise_dtypes(
*flat_args, type_promotion_kind=type_promotion
@ -98,6 +104,11 @@ compute_only_pw_cast_for_opmath = partial(
pw_cast_for_opmath = partial(
type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
pw_cast_for_opmath_non_tensor_args = partial(
type_casts,
type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
include_non_tensor_args=True,
)
pw_cast_for_int_to_real = partial(
type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)