mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix clamp type promotion in inductor decomposition (#154471)
Summary: as title, the clamp type promotion should take min/max arg into consideration as well. Test Plan: ``` buck run fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_clamp_decomposition_cpu python test/inductor/test_torchinductor.py -k test_clamp -v ``` Differential Revision: D75490124 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154471 Approved by: https://github.com/desertfire, https://github.com/chenyang78
This commit is contained in:
committed by
PyTorch MergeBot
parent
d865b784e4
commit
3e05a48927
@ -58,11 +58,17 @@ def type_casts(
|
||||
f: Callable,
|
||||
type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
compute_dtype_only: bool = False,
|
||||
include_non_tensor_args: bool = False,
|
||||
):
|
||||
@functools.wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
allowed_types = (
|
||||
(Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,)
|
||||
) # type: ignore[arg-type]
|
||||
flat_args = [
|
||||
x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
|
||||
x
|
||||
for x in pytree.arg_tree_leaves(*args, **kwargs)
|
||||
if isinstance(x, allowed_types)
|
||||
]
|
||||
computation_dtype, result_dtype = utils.elementwise_dtypes(
|
||||
*flat_args, type_promotion_kind=type_promotion
|
||||
@ -98,6 +104,11 @@ compute_only_pw_cast_for_opmath = partial(
|
||||
pw_cast_for_opmath = partial(
|
||||
type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
|
||||
)
|
||||
pw_cast_for_opmath_non_tensor_args = partial(
|
||||
type_casts,
|
||||
type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
include_non_tensor_args=True,
|
||||
)
|
||||
pw_cast_for_int_to_real = partial(
|
||||
type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
|
||||
)
|
||||
|
Reference in New Issue
Block a user