Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-31 12:15:03 +08:00
Fixes #109605

Generated code before:

```
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (8, ), (1, ))
    buf0 = empty_strided((), (), device='cpu', dtype=torch.int64)
    cpp_fused_lift_fresh_0(c_void_p(buf0.data_ptr()))
    # Source Nodes: [wrapped_pow], Original ATen: [aten.lift_fresh, aten.pow]
    buf1 = aten.pow(arg0_1, reinterpret_tensor(buf0, (8, ), (0, ), 0))
    del arg0_1
    del buf0
    buf2 = buf1
    assert_size_stride(buf2, (8, ), (1, ))
    del buf1
    return (buf2, )
```

Generated code now:

```
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (8, ), (1, ))
    buf0 = empty_strided((8, ), (1, ), device='cpu', dtype=torch.int64)
    cpp_fused_pow_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
    del arg0_1
    return (buf0, )
```

@lezcano What would be a good way to add a test for this?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109953
Approved by: https://github.com/lezcano
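Editor's note: one hedged answer to the testing question above is an eager-vs-compiled comparison over the NumPy-interop path. The sketch below is illustrative only (the function body, the exponent, and the use of `np.arange` are assumptions, not the PR's actual test); the idea is that any `np.power` call with a Python-scalar exponent compiled via `torch.compile` should exercise the `wrapped_pow` lowering that this PR changes.

```python
# Hypothetical regression sketch (not the PR's test): compile a NumPy power
# call with a scalar exponent and check it matches the eager result.
import numpy as np
import torch

def fn(x):
    return np.power(x, 2)  # scalar exponent triggers the wrapped_pow path

x = np.arange(8)
expected = fn(x)                   # eager NumPy result
actual = torch.compile(fn)(x)      # routed through inductor codegen
np.testing.assert_allclose(actual, expected)
```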
333 lines · 8.1 KiB · Python
from __future__ import annotations

from typing import Optional

import torch

from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
from ._normalizations import (
    ArrayLike,
    ArrayLikeOrScalar,
    CastingModes,
    DTypeLike,
    normalizer,
    NotImplementedType,
    OutArray,
)


def _ufunc_postprocess(result, out, casting):
    if out is not None:
        result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
        result = torch.broadcast_to(result, out.shape)
    return result

 | |
# ############# Binary ufuncs ######################

_binary = [
    name
    for name in dir(_binary_ufuncs_impl)
    if not name.startswith("_") and name not in ["torch", "matmul", "divmod", "ldexp"]
]


NEP50_FUNCS = (
    "add",
    "subtract",
    "multiply",
    "floor_divide",
    "true_divide",
    "divide",
    "remainder",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "hypot",
    "arctan2",
    "logaddexp",
    "logaddexp2",
    "heaviside",
    "copysign",
    "fmax",
    "minimum",
    "fmin",
    "maximum",
    "fmod",
    "gcd",
    "lcm",
    "pow",
)

 | |
def deco_binary_ufunc(torch_func):
    """Common infra for binary ufuncs.

    Normalize arguments, sort out type casting, broadcasting and delegate to
    the pytorch functions for the actual work.
    """

    @normalizer
    def wrapped(
        x1: ArrayLikeOrScalar,
        x2: ArrayLikeOrScalar,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        if dtype is not None:

            def cast(x, dtype):
                if isinstance(x, torch.Tensor):
                    return _util.typecast_tensor(x, dtype, casting)
                else:
                    return torch.as_tensor(x, dtype=dtype)

            x1 = cast(x1, dtype)
            x2 = cast(x2, dtype)
        elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
            dtype = _dtypes_impl.result_type_impl(x1, x2)
            x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
        else:
            x1, x2 = _dtypes_impl.nep50_to_tensors(
                x1, x2, torch_func.__name__ in NEP50_FUNCS, torch_func.__name__
            )

        result = torch_func(x1, x2)

        return _ufunc_postprocess(result, out, casting)

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped

 | |
# matmul's signature is _slightly_ different from other ufuncs:
# - no where=...
# - additional axis=..., axes=...
# - no NEP50 scalars in or out
@normalizer
def matmul(
    x1: ArrayLike,
    x2: ArrayLike,
    /,
    out: Optional[OutArray] = None,
    *,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
    axes: NotImplementedType = None,
    axis: NotImplementedType = None,
):
    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    result = _binary_ufuncs_impl.matmul(x1, x2)

    result = _ufunc_postprocess(result, out, casting)
    return result

 | |
# ldexp casting is special: the dtype of the result == dtype of the 1st arg
@normalizer
def ldexp(
    x1: ArrayLikeOrScalar,
    x2: ArrayLikeOrScalar,
    /,
    out: Optional[OutArray] = None,
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    if dtype is not None:
        if isinstance(x1, torch.Tensor):
            x1 = _util.typecast_tensor(x1, dtype, casting)
        else:
            x1 = torch.as_tensor(x1, dtype=dtype)
    else:
        if not isinstance(x1, torch.Tensor):
            x1 = torch.as_tensor(x1)
            x1 = _util.cast_int_to_float(x1)

    x2 = torch.as_tensor(x2)
    # the second arg must be an integer
    if _dtypes_impl._category(x2.dtype) != 1:
        raise ValueError("ldexp 2nd arg must be integer")

    result = _binary_ufuncs_impl.ldexp(x1, x2)

    if x1.dtype == torch.float16:
        # torch.ldexp(f16, int) -> f32, undo it
        result = result.to(torch.float16)

    return _ufunc_postprocess(result, out, casting)

 | |
# nin=2, nout=2
@normalizer
def divmod(
    x1: ArrayLike,
    x2: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    # make sure we either have no out arrays at all, or we have both
    # positional (out1, out2), or the keyword out=tuple -- but not both forms
    num_outs = sum(x is not None for x in [out1, out2])
    if num_outs == 1:
        raise ValueError("both out1 and out2 need to be provided")
    elif num_outs == 2:
        o1, o2 = out
        if o1 is not None or o2 is not None:
            raise TypeError(
                "cannot specify 'out' as both a positional and keyword argument"
            )
    else:
        out1, out2 = out

    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

    quot = _ufunc_postprocess(quot, out1, casting)
    rem = _ufunc_postprocess(rem, out2, casting)
    return quot, rem

 | |
#
# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
#
for name in _binary:
    ufunc = getattr(_binary_ufuncs_impl, name)
    vars()[name] = deco_binary_ufunc(ufunc)

 | |
def modf(x, /, *args, **kwds):
    quot, rem = divmod(x, 1, *args, **kwds)
    return rem, quot

 | |
_binary = _binary + ["divmod", "modf", "matmul", "ldexp"]

 | |
# ############# Unary ufuncs ######################


_unary = [
    name
    for name in dir(_unary_ufuncs_impl)
    if not name.startswith("_") and name != "torch"
]


# these are ufunc(int) -> float
_fp_unary = [
    "arccos",
    "arccosh",
    "arcsin",
    "arcsinh",
    "arctan",
    "arctanh",
    "cbrt",
    "cos",
    "cosh",
    "deg2rad",
    "degrees",
    "exp",
    "exp2",
    "expm1",
    "log",
    "log10",
    "log1p",
    "log2",
    "rad2deg",
    "radians",
    "reciprocal",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]

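# Editor's note (illustrative, not in the original file): the names above are
# routed through _util.cast_int_to_float in deco_unary_ufunc below, so integer
# inputs come back as the default float dtype; e.g. ``tnp.sqrt(tnp.asarray([4]))``
# is expected to be ``array([2.])`` (assuming float64 is the configured
# default float dtype).
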
 | |
| 
 | |
| 
 | |
def deco_unary_ufunc(torch_func):
    """Common infra for unary ufuncs.

    Normalize arguments, sort out type casting, broadcasting and delegate to
    the pytorch functions for the actual work.
    """

    @normalizer
    def wrapped(
        x: ArrayLike,
        /,
        out: Optional[OutArray] = None,
        *,
        where=True,
        casting: Optional[CastingModes] = "same_kind",
        order="K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature=None,
        extobj=None,
    ):
        if dtype is not None:
            x = _util.typecast_tensor(x, dtype, casting)

        if torch_func.__name__ in _fp_unary:
            x = _util.cast_int_to_float(x)

        result = torch_func(x)
        result = _ufunc_postprocess(result, out, casting)
        return result

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped

 | |
#
# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
#
for name in _unary:
    ufunc = getattr(_unary_ufuncs_impl, name)
    vars()[name] = deco_unary_ufunc(ufunc)


__all__ = _binary + _unary  # noqa: PLE0605