mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix clamp type promotion in inductor decomposition (#154471)
Summary: as title, the clamp type promotion should take min/max arg into consideration as well. Test Plan: ``` buck run fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_clamp_decomposition_cpu python test/inductor/test_torchinductor.py -k test_clamp -v ``` Differential Revision: D75490124 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154471 Approved by: https://github.com/desertfire, https://github.com/chenyang78
This commit is contained in:
committed by
PyTorch MergeBot
parent
d865b784e4
commit
3e05a48927
@ -5851,6 +5851,22 @@ class AOTInductorTestsTemplate:
|
||||
# compare against eager
|
||||
self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
|
||||
|
||||
def test_clamp_decomposition(self):
|
||||
class Model1(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.clamp(min=1.5)
|
||||
|
||||
class Model2(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.clamp(min=2)
|
||||
|
||||
x = torch.randint(4, (4,))
|
||||
|
||||
# the output should have float32 type, not int
|
||||
self.check_model(Model1(), (x,))
|
||||
# the output should have int type
|
||||
self.check_model(Model2(), (x,))
|
||||
|
||||
|
||||
class AOTInductorLoggingTest(LoggingTestCase):
|
||||
@make_logging_test(dynamic=logging.DEBUG)
|
||||
|
@ -2696,6 +2696,12 @@ class CommonTemplate:
|
||||
|
||||
self.common(fn, (torch.randint(4, (4,)),))
|
||||
|
||||
def test_clamp_type_promotion_non_tensor(self):
|
||||
def fn(a):
|
||||
return a.clamp(min=1.5), a.clamp(min=2)
|
||||
|
||||
self.common(fn, (torch.randint(4, (4,)),))
|
||||
|
||||
@skip_if_gpu_halide
|
||||
@xfail_if_triton_cpu
|
||||
def test_dist(self):
|
||||
|
@ -58,11 +58,17 @@ def type_casts(
|
||||
f: Callable,
|
||||
type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
compute_dtype_only: bool = False,
|
||||
include_non_tensor_args: bool = False,
|
||||
):
|
||||
@functools.wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
allowed_types = (
|
||||
(Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,)
|
||||
) # type: ignore[arg-type]
|
||||
flat_args = [
|
||||
x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
|
||||
x
|
||||
for x in pytree.arg_tree_leaves(*args, **kwargs)
|
||||
if isinstance(x, allowed_types)
|
||||
]
|
||||
computation_dtype, result_dtype = utils.elementwise_dtypes(
|
||||
*flat_args, type_promotion_kind=type_promotion
|
||||
@ -98,6 +104,11 @@ compute_only_pw_cast_for_opmath = partial(
|
||||
pw_cast_for_opmath = partial(
|
||||
type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
|
||||
)
|
||||
pw_cast_for_opmath_non_tensor_args = partial(
|
||||
type_casts,
|
||||
type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
include_non_tensor_args=True,
|
||||
)
|
||||
pw_cast_for_int_to_real = partial(
|
||||
type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
|
||||
)
|
||||
|
@ -22,6 +22,7 @@ from torch._decomp.decompositions import (
|
||||
_index_add,
|
||||
embedding_dense_backward as decomp_embedding_dense_backward,
|
||||
pw_cast_for_opmath,
|
||||
pw_cast_for_opmath_non_tensor_args,
|
||||
)
|
||||
from torch._decomp.decompositions_for_rng import extra_random_decomps
|
||||
from torch._dynamo.utils import counters
|
||||
@ -181,7 +182,7 @@ def sym_constrain_range_for_size(
|
||||
|
||||
|
||||
@register_decomposition([aten.clamp])
|
||||
@pw_cast_for_opmath
|
||||
@pw_cast_for_opmath_non_tensor_args
|
||||
def clamp(
|
||||
x: torch.Tensor,
|
||||
min: Optional[torch.types.Number] = None,
|
||||
|
Reference in New Issue
Block a user