Fix clamp type promotion in inductor decomposition (#154471)

Summary: As the title says, clamp's type promotion should take the min/max arguments into consideration as well, not just the tensor input.
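
For reference, this is the eager behavior the decomposition has to match (a minimal sketch, not part of the change itself):

```
import torch

x = torch.randint(4, (4,))     # int64 input
x.clamp(min=1.5).dtype         # torch.float32: the float min promotes the result
x.clamp(min=2).dtype           # torch.int64: an int min keeps the integer dtype
```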

Test Plan:
```
buck run fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_clamp_decomposition_cpu
python test/inductor/test_torchinductor.py -k test_clamp -v
```

Differential Revision: D75490124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154471
Approved by: https://github.com/desertfire, https://github.com/chenyang78

Author: Shangdi Yu
Date: 2025-05-28 23:24:25 +00:00
Committer: PyTorch MergeBot
Parent: d865b784e4
Commit: 3e05a48927
4 changed files with 36 additions and 2 deletions

@@ -5851,6 +5851,22 @@ class AOTInductorTestsTemplate:
         # compare against eager
         self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
 
+    def test_clamp_decomposition(self):
+        class Model1(torch.nn.Module):
+            def forward(self, x):
+                return x.clamp(min=1.5)
+
+        class Model2(torch.nn.Module):
+            def forward(self, x):
+                return x.clamp(min=2)
+
+        x = torch.randint(4, (4,))
+
+        # the output should have float32 type, not int
+        self.check_model(Model1(), (x,))
+        # the output should have int type
+        self.check_model(Model2(), (x,))
+
 class AOTInductorLoggingTest(LoggingTestCase):
     @make_logging_test(dynamic=logging.DEBUG)

@@ -2696,6 +2696,12 @@ class CommonTemplate:
 
         self.common(fn, (torch.randint(4, (4,)),))
 
+    def test_clamp_type_promotion_non_tensor(self):
+        def fn(a):
+            return a.clamp(min=1.5), a.clamp(min=2)
+
+        self.common(fn, (torch.randint(4, (4,)),))
+
     @skip_if_gpu_halide
     @xfail_if_triton_cpu
     def test_dist(self):
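
Outside the test suites, the same check can be reproduced directly with torch.compile; a minimal sketch along the lines of the new test (not part of the diff):

```
import torch

def fn(a):
    # the float min should promote the int input; the int min should not
    return a.clamp(min=1.5), a.clamp(min=2)

x = torch.randint(4, (4,))
eager = fn(x)
compiled = torch.compile(fn)(x)
# assert_close compares dtypes as well as values by default
torch.testing.assert_close(compiled, eager)
```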

@@ -58,11 +58,17 @@ def type_casts(
     f: Callable,
     type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
     compute_dtype_only: bool = False,
+    include_non_tensor_args: bool = False,
 ):
     @functools.wraps(f)
     def inner(*args, **kwargs):
+        allowed_types = (
+            (Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,)
+        )  # type: ignore[arg-type]
         flat_args = [
-            x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
+            x
+            for x in pytree.arg_tree_leaves(*args, **kwargs)
+            if isinstance(x, allowed_types)
         ]
         computation_dtype, result_dtype = utils.elementwise_dtypes(
             *flat_args, type_promotion_kind=type_promotion
@@ -98,6 +104,11 @@ compute_only_pw_cast_for_opmath = partial(
 pw_cast_for_opmath = partial(
     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
 )
+pw_cast_for_opmath_non_tensor_args = partial(
+    type_casts,
+    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    include_non_tensor_args=True,
+)
 pw_cast_for_int_to_real = partial(
     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
 )
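
Here `utils` is `torch._prims_common`, and `elementwise_dtypes` only promotes across the arguments it is handed, which is why non-tensor args now need to be let through the filter. A rough illustration, assuming the `torch._prims_common` API (not part of the diff):

```
import torch
import torch._prims_common as utils

x = torch.randint(4, (4,))

# Tensor arg only: the int64 input decides both dtypes.
utils.elementwise_dtypes(
    x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)  # (torch.int64, torch.int64)

# With the float scalar included, the result dtype is promoted to float32.
utils.elementwise_dtypes(
    x, 1.5, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)  # (torch.float32, torch.float32)
```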

@@ -22,6 +22,7 @@ from torch._decomp.decompositions import (
     _index_add,
     embedding_dense_backward as decomp_embedding_dense_backward,
     pw_cast_for_opmath,
+    pw_cast_for_opmath_non_tensor_args,
 )
 from torch._decomp.decompositions_for_rng import extra_random_decomps
 from torch._dynamo.utils import counters
@@ -181,7 +182,7 @@ def sym_constrain_range_for_size(
 @register_decomposition([aten.clamp])
-@pw_cast_for_opmath
+@pw_cast_for_opmath_non_tensor_args
 def clamp(
     x: torch.Tensor,
     min: Optional[torch.types.Number] = None,
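
Put together, the wrapped decomposition now behaves roughly like the sketch below; the helper name is made up and this only illustrates the wrapper's effect, it is not the actual inductor code:

```
import torch
import torch._prims_common as utils

def clamp_with_promotion(x, min=None, max=None):
    # dtypes are computed from the tensor *and* any scalar bounds,
    # mirroring what pw_cast_for_opmath_non_tensor_args now does for clamp
    scalars = [a for a in (min, max) if a is not None]
    compute_dtype, result_dtype = utils.elementwise_dtypes(
        x, *scalars, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    return torch.clamp(x.to(compute_dtype), min=min, max=max).to(result_dtype)

x = torch.randint(4, (4,))
assert clamp_with_promotion(x, min=1.5).dtype == torch.float32
assert clamp_with_promotion(x, min=2).dtype == torch.int64
```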