add decomposition for frexp (#119217)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119217
Approved by: https://github.com/peterbell10
ghstack dependencies: #119284, #120027
Isuru Fernando
2024-02-23 18:06:21 +00:00
committed by PyTorch MergeBot
parent f7e79299c7
commit b7df3bba62
22 changed files with 198 additions and 36 deletions

View File

@ -238,6 +238,8 @@ aten::fmod.Tensor
aten::fmod.Tensor_out
aten::fmod_.Scalar
aten::fmod_.Tensor
aten::frexp.Tensor
aten::frexp.Tensor_out
aten::full
aten::full.out
aten::gcd

View File

@ -773,8 +773,6 @@ aten::fractional_max_pool3d
aten::fractional_max_pool3d.output
aten::fractional_max_pool3d_backward
aten::fractional_max_pool3d_backward.grad_input
aten::frexp.Tensor
aten::frexp.Tensor_out
aten::from_file
aten::from_file.out
aten::full.names

View File

@ -4381,7 +4381,6 @@ aot_autograd_failures = {
symbolic_aot_autograd_failures = {
xfail('combinations', ''), # aten.masked_select.default
xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.eigvals', ''), # aten.linalg_eig.default - couldn't find symbolic meta function/decomposition

View File

@ -1663,6 +1663,7 @@ class CPUReproTests(TestCase):
"index_expr",
"signbit",
"isinf",
"frexp",
"mod",
"masked",
"randn",

View File

@ -641,7 +641,6 @@ meta_function_expected_failures = {
torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
torch.bincount : {i32, i64, u8, i16, i8},
torch.frexp : {f64, f16, bf16, f32},
torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
torch.histc : {f64, f16, bf16, f32},
@ -815,7 +814,6 @@ meta_dispatch_expected_failures = {
aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
aten.bincount.default : {i64, i8, i32, i16, u8},
aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.frexp.Tensor : {bf16, f32, f16, f64},
aten.histc.default : {bf16, f32, f64},
aten.histc.out : {bf16, f32, f64},
aten.histogram.bin_ct : {f32, f64},

View File

@ -1978,7 +1978,12 @@ class TestRefsOpsInfo(TestCase):
'_refs.imag',
'_refs.reshape_as',
'_refs.view_as',
'_refs.view_as_complex' # TorchInductor does not support complex at the moment.
'_refs.view_as_complex', # TorchInductor does not support complex at the moment.
# the decompositions for these ops are slightly different
# because of out handling
'_refs.var_mean',
'_refs.std_mean',
'_refs.native_layer_norm',
}
@parametrize("op", ref_ops_names)

View File

@ -1876,7 +1876,6 @@ symbolic_tensor_failures = {
xfail('linalg.eig'),
xfail('linalg.eigvals'),
xfail('combinations', ''),
xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition
xfail('histc', ''), # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because...
xfail('histogram', ''), # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...

View File

@ -20,6 +20,7 @@ from torch.testing._internal.common_utils import (
skipIfNoSciPy,
IS_WINDOWS,
gradcheck,
is_iterable_of_tensors,
)
from torch.testing._internal.common_methods_invocations import (
unary_ufuncs,
@ -39,6 +40,7 @@ from torch.testing._internal.common_device_type import (
precisionOverride,
dtypesIfCPU,
)
from torch.utils import _pytree as pytree
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
@ -318,9 +320,10 @@ class TestUnaryUfuncs(TestCase):
self.assertFalse(non_contig.is_contiguous())
torch_kwargs, _ = op.sample_kwargs(device, dtype, non_contig)
self.assertEqual(
op(contig, **torch_kwargs)[::2], op(non_contig, **torch_kwargs)
)
expected = op(non_contig, **torch_kwargs)
result = op(contig, **torch_kwargs)
result = pytree.tree_map(lambda x: x[::2], result)
self.assertEqual(result, expected)
@ops(unary_ufuncs)
def test_contig_vs_transposed(self, device, dtype, op):
@ -333,7 +336,10 @@ class TestUnaryUfuncs(TestCase):
self.assertFalse(non_contig.is_contiguous())
torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
self.assertEqual(op(contig, **torch_kwargs).T, op(non_contig, **torch_kwargs))
expected = op(non_contig, **torch_kwargs)
result = op(contig, **torch_kwargs)
result = pytree.tree_map(lambda x: x.T, result)
self.assertEqual(result, expected)
@ops(unary_ufuncs)
def test_non_contig(self, device, dtype, op):
@ -385,8 +391,9 @@ class TestUnaryUfuncs(TestCase):
contig = op(contig, **torch_kwargs)
non_contig = op(non_contig, **torch_kwargs)
for i in range(3):
non_contig_i = pytree.tree_map(lambda x: x[i], non_contig)
self.assertEqual(
contig, non_contig[i], msg="non-contiguous expand[" + str(i) + "]"
contig, non_contig_i, msg="non-contiguous expand[" + str(i) + "]"
)
@ops(unary_ufuncs)
@ -433,7 +440,12 @@ class TestUnaryUfuncs(TestCase):
torch_kwargs, _ = op.sample_kwargs(device, dtype, input)
actual = op(input, **torch_kwargs)
expected = torch.stack([op(slice, **torch_kwargs) for slice in input])
all_outs = [op(slice, **torch_kwargs) for slice in input]
if is_iterable_of_tensors(actual):
expected = [torch.stack([out[i] for out in all_outs]) for i in range(len(actual))]
else:
expected = torch.stack(all_outs)
self.assertEqual(actual, expected)
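
The test tweaks above are needed because a unary ufunc like frexp returns a tuple of tensors rather than a single tensor, so slicing, transposing, and stacking have to be applied per output. A minimal illustration of the idea, not taken from the diff:

import torch
from torch.utils import _pytree as pytree

# frexp returns two tensors; map the same slice over each of them,
# exactly as the updated tests do via pytree.tree_map.
mantissa, exponent = torch.frexp(torch.randn(6))
sliced = pytree.tree_map(lambda t: t[::2], (mantissa, exponent))
assert all(t.shape == (3,) for t in sliced)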

View File

@ -171,6 +171,7 @@ def register_decomposition(
assert type in {"post_autograd", "pre_autograd", "meta"}
def decomposition_decorator(fn: Callable) -> Callable:
orig_fn = fn
if not unsafe:
fn = _convert_out_params(fn)
@ -183,7 +184,7 @@ def register_decomposition(
# To handle allowing multiple aten_ops at once
pytree.tree_map_(register, aten_op)
return fn
return orig_fn
return decomposition_decorator
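
With this change the decomposition table receives the out=-converted wrapper while the decorated function itself is handed back unchanged, so a name defined in a module (for example _refs.frexp below) keeps its original signature; the torch_decomp_decompositions change later in this commit is the matching consumer-side fix. A simplified sketch of the pattern, with hypothetical helper names:

registry = {}

def _add_out_handling(fn):
    # Stand-in for _convert_out_params: wrap fn so the registered copy
    # also accepts an out= argument.
    def wrapper(*args, out=None, **kwargs):
        result = fn(*args, **kwargs)
        return result if out is None else out.copy_(result)
    return wrapper

def register_decomp(key):
    def decorator(fn):
        registry[key] = _add_out_handling(fn)  # table entry: the wrapper
        return fn                              # caller keeps the original
    return decorator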

View File

@ -676,6 +676,17 @@ Either create the tensor outside the compiled region, or do not set the tensor t
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
out_tensor.source
and out_tensor in tx.output.graphargs
and out_tensor.size != result_tensor.size
):
# It's hard to get out variants with resizing on graph inputs to work
# properly across dynamo/aot/inductor, so just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
if (
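
The new branch above handles out= arguments that are a list or tuple of tensors, which is what frexp's out variant takes; when such an out tensor is a graph input whose size does not match the result, Dynamo now falls back instead of trying to trace the resize. A hedged illustration of the case that would hit this graph break, not taken from the PR's tests:

import torch

def f(x, m, e):
    # Tuple-valued out=; m and e are graph inputs here.
    return torch.frexp(x, out=(m, e))

x = torch.randn(8)
m = torch.empty(0)                      # wrong size: would need resizing
e = torch.empty(0, dtype=torch.int32)
torch.compile(f)(x, m, e)               # expected to fall back rather than trace the resize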

View File

@ -26,6 +26,7 @@ from sympy.printing.printer import Printer
import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.value_ranges import ValueRanges
from .. import config, metrics
@ -1441,13 +1442,14 @@ class Kernel(CodeGen):
fx_node, ValueRanges.unknown()
)
csevar = self.cse.generate(
self.compute,
getattr(parent_handler, name)(*args, **kwargs), # type: ignore[has-type]
bounds=buf_bounds,
)
csevar.update_on_args(name, args, kwargs)
return csevar
value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
def do_cse(v):
csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
csevar.update_on_args(name, args, kwargs)
return csevar
return pytree.tree_map(do_cse, value)
return inner
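
The indirection above is needed because an op such as frexp hands back a pair of expressions rather than one, and each element must get its own CSE variable; pytree.tree_map keeps the single-output path unchanged. Roughly, as an illustration only:

from torch.utils import _pytree as pytree

def do_cse(expr):
    # Stand-in for self.cse.generate(...): give each expression a variable.
    return f"tmp<{expr}>"

print(pytree.tree_map(do_cse, "a + b"))                         # one expression -> one var
print(pytree.tree_map(do_cse, ("frexp(x)[0]", "frexp(x)[1]")))  # frexp -> a pair of vars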

View File

@ -16,6 +16,7 @@ import torch.fx
from torch._inductor import dependencies
from torch._inductor.ir import StorageBox, TensorBox
from torch._prims_common import is_float_dtype
from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
@ -804,6 +805,23 @@ class CppOverrides(OpOverrides):
def copysign(x, y):
return f"std::copysign({x}, {y})"
@staticmethod
def frexp(x):
cache_keys = f"frexp({x})[0]", f"frexp({x})[1]"
if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys):
return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys)
code = BracesBuffer()
exponent = V.kernel.cse.newvar()
mantissa = V.kernel.cse.newvar()
code.writeline(f"int32_t {exponent};")
code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});")
V.kernel.compute.splice(code)
cse_vars = (mantissa, exponent)
for cache_key, cse_var in zip(cache_keys, cse_vars):
V.kernel.cse.cache[cache_key] = cse_var
return mantissa, exponent
@staticmethod
def hypot(x, y):
return f"std::hypot({x}, {y})"
@ -2760,6 +2778,8 @@ class CppVecKernelChecker(CppVecKernel):
self._orig_wrapper_code = V.graph.wrapper_code
V.graph.wrapper_code = WrapperCodeGen()
parent_handler = V.MockHandler()
class VecCheckerProxy:
bin_cmp_ops = ["eq", "ne", "le", "ge", "lt", "gt"]
@ -2778,7 +2798,9 @@ class CppVecKernelChecker(CppVecKernel):
if name not in self.fast_vec_list:
self.disable_vec(f"op: {name}")
return self.simd_vec
parent_val = getattr(parent_handler, name)(*args, **kwargs)
return pytree.tree_map(lambda _: self.simd_vec, parent_val)
return inner
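
Because CSE caching is keyed by single expression strings, the frexp override above stores its two results under separate per-output keys, so a repeated frexp(x) in the same kernel reuses both variables instead of emitting std::frexp twice. A stripped-down sketch of that caching scheme, with hypothetical newvar/lines helpers:

import itertools

cache, lines = {}, []
_counter = itertools.count()

def newvar():
    return f"tmp{next(_counter)}"

def frexp_codegen(x):
    keys = (f"frexp({x})[0]", f"frexp({x})[1]")
    if all(k in cache for k in keys):
        return tuple(cache[k] for k in keys)
    mantissa, exponent = newvar(), newvar()
    lines.append(f"int32_t {exponent};")
    lines.append(f"auto {mantissa} = std::frexp({x}, &{exponent});")
    cache[keys[0]], cache[keys[1]] = mantissa, exponent
    return mantissa, exponent

frexp_codegen("in_ptr0[i0]")
frexp_codegen("in_ptr0[i0]")  # cache hit: no extra C++ lines emitted
assert len(lines) == 2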

View File

@ -905,6 +905,20 @@ class TritonKernelOverrides(TritonOverrides):
f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})"
)
@staticmethod
def frexp(x):
cache_key = f"frexp({x})"
if cache_key in V.kernel.cse.cache:
return V.kernel.cse.cache[cache_key]
mantissa = V.kernel.cse.newvar()
exponent = V.kernel.cse.newvar()
V.kernel.compute.writeline(
f"{mantissa}, {exponent} = triton_helpers.frexp({x})"
)
V.kernel.cse.cache[cache_key] = (mantissa, exponent)
return (mantissa, exponent)
# Use mypy to check protocol implemented correctly
def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]:

View File

@ -472,6 +472,9 @@ class FreeUnbackedSymbolsOpsHandler:
self.symbols |= free_unbacked_symbols(size)
return sympy_index_symbol(f"({str(index_var)})")
def frexp(self, x):
return (None,) * 2
def reduction(
self,
dtype: torch.dtype,

View File

@ -7039,6 +7039,11 @@ class LoopBodyBlock:
"call_module", name, (dtype_proxy, value_proxy, init_proxy), {}
)
def frexp(self, value_proxy):
result = self._inner.frexp(value_proxy)
# Proxies are iterable, but some methods expect tuples/lists
return (result[0], result[1])
@staticmethod
def indirect_indexing(index_proxy, size, check=True):
"""

View File

@ -645,6 +645,40 @@ def register_pointwise(
return fn
def register_frexp():
"""A pointwise function that maps ops.frexp to inputs"""
name = "frexp"
frexp = ops_wrapper("frexp")
def frexp0(*args, **kwargs):
return frexp(*args, **kwargs)[0]
def frexp1(*args, **kwargs):
return frexp(*args, **kwargs)[1]
pw_fns = [
make_pointwise(frexp0),
make_pointwise(frexp1, override_return_dtype=torch.int32),
]
def fn(*args, **kwargs):
return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs)
fn = register_lowering(
aten.frexp,
)(fn)
if hasattr(prims, name):
register_lowering(
getattr(prims, name),
type_promotion_kind=None,
)(fn)
return fn
register_frexp()
def register_foreach_pointwise(
aten_fn,
pointwise_lowering_fn,
@ -2278,7 +2312,6 @@ make_fallback(aten._efficientzerotensor)
make_fallback(aten._embedding_bag_per_sample_weights_backward)
make_fallback(aten.fractional_max_pool2d)
make_fallback(aten.fractional_max_pool3d)
make_fallback(aten.frexp)
make_fallback(aten.geqrf)
make_fallback(aten.histc)
make_fallback(aten.kthvalue)
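
With the lowering registered and make_fallback(aten.frexp) removed, Inductor generates frexp inline (std::frexp on CPU, triton_helpers.frexp on CUDA) and can fuse it with surrounding pointwise ops instead of falling back to ATen. A quick way to exercise the new path, illustrative only and checked against eager:

import torch

def f(x):
    mantissa, exponent = torch.frexp(x)
    return mantissa * 2.0, exponent + 1   # pointwise consumers that can fuse with frexp

x = torch.randn(1024)
torch.testing.assert_close(torch.compile(f)(x), f(x))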

View File

@ -6,6 +6,7 @@ import sympy
from typing_extensions import Protocol
import torch
import torch.utils._pytree as pytree
from torch.fx.graph import inplace_methods, magic_methods
from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str
@ -304,6 +305,9 @@ class OpsHandler(Protocol[T]):
def erfinv(self, x0: T) -> T:
...
def frexp(self, x0: T):
...
def hypot(self, x0: T, x1: T) -> T:
...
@ -502,6 +506,10 @@ class MockHandler:
def masked(mask, body, other) -> str:
return f"ops.masked({mask}, {body()}, {other})"
@staticmethod
def frexp(x):
return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]")
@staticmethod
def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
return sympy_index_symbol(f"({str(index_var)})")
@ -567,10 +575,14 @@ class KernelFormatterHandler:
line = getattr(self.parent_handler, name)(*args, **kwargs)
if name == "indirect_indexing":
return line
# replace line with a new variable name
varname = f"tmp{next(self.var_counter)}"
self.output.writeline(f"{varname} = {line}")
return varname
def write(line):
# replace line with a new variable name
varname = f"tmp{next(self.var_counter)}"
self.output.writeline(f"{varname} = {line}")
return varname
return pytree.tree_map(write, line)
return inner
@ -624,13 +636,17 @@ class OpCounterCSE:
val = getattr(self.parent_handler, name)(*args, **kwargs)
if name == "indirect_indexing":
return val
if val not in self.var_names:
varname = f"tmp{self.op_count}"
self.op_count += 1
self.var_names[val] = varname
return varname
else:
return self.var_names[val]
def count(val):
if val not in self.var_names:
varname = f"tmp{self.op_count}"
self.op_count += 1
self.var_names[val] = varname
return varname
else:
return self.var_names[val]
return pytree.tree_map(count, val)
return inner

View File

@ -322,3 +322,12 @@ def exclusive_scan_decoupled_lookback_64(
tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")
return exclusive_prefix
@triton.jit
def frexp(x):
# TODO(isuruf): use inline_asm_elementwise here
y = tl.math.ilogb(x) + 1
exponent = tl.where(x == 0, 0, y)
mantissa = tl.where(x == 0, 0, tl.math.ldexp(x, -y))
return mantissa, exponent
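
The helper above uses the identity behind frexp: every finite nonzero float can be written as x = m * 2**e with 0.5 <= |m| < 1, and since ilogb(x) = floor(log2(|x|)), choosing e = ilogb(x) + 1 and m = ldexp(x, -e) = x * 2**(-e) lands m in that interval; the tl.where guards define frexp(0) as (0, 0). A small eager-mode check of the same contract (illustrative, not from the PR):

import torch

x = torch.tensor([0.0, 1.0, -3.5, 0.15625, 1e-30])
mantissa, exponent = torch.frexp(x)
# Reconstruction is exact: mantissa * 2**exponent == x.
assert torch.equal(torch.ldexp(mantissa, exponent), x)
m = mantissa[x != 0].abs()
assert bool(((m >= 0.5) & (m < 1.0)).all())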

View File

@ -113,6 +113,7 @@ __all__ = [
"fmax",
"fmin",
"fmod",
"frexp",
"gcd",
"ge",
"gt",
@ -3009,5 +3010,22 @@ fft_c2r = _make_prim(
doc=_fft_c2r_doc,
)
def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
torch._check(
self.dtype.is_floating_point,
lambda: "torch.frexp() only supports floating-point dtypes",
)
return torch.empty_like(self), torch.empty_like(self, dtype=torch.int32)
frexp = _make_prim(
schema="frexp(Tensor self) -> (Tensor mantissa, Tensor exponent)",
meta=_frexp_meta,
return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW),
impl_aten=torch.frexp,
doc="",
)
register_rng_prims()
register_debug_prims()
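
The meta function above is what lets aten.frexp run under FakeTensor/meta tracing (hence the entries dropped from the meta and symbolic expected-failure lists earlier in this commit): it promises a mantissa with the input's floating dtype and an int32 exponent, both shaped like the input. A hedged sanity check on the meta device, assuming a build where this registration is active:

import torch

x = torch.empty(4, 3, dtype=torch.float16, device="meta")
mantissa, exponent = torch.frexp(x)
assert mantissa.dtype == torch.float16 and mantissa.shape == x.shape
assert exponent.dtype == torch.int32 and exponent.shape == x.shape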

View File

@ -1368,6 +1368,12 @@ def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
return prims.fmod(a, b)
@register_decomposition(aten.frexp)
@out_wrapper("mantissa", "exponent")
def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
return torch.return_types.frexp(prims.frexp(self))
@_make_elementwise_binary_reference(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
supports_lhs_python_scalar=False,
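
The reference above pairs the prim with out_wrapper("mantissa", "exponent"), which supplies the out= handling expected by aten::frexp.Tensor_out and keeps the result a named tuple with those field names. Roughly what calling it looks like (illustrative; the values follow from 1.5 = 0.75 * 2**1 and 8.0 = 0.5 * 2**4):

import torch
import torch._refs as refs

result = refs.frexp(torch.tensor([1.5, 8.0]))
assert torch.equal(result.mantissa, torch.tensor([0.75, 0.50]))
assert torch.equal(result.exponent, torch.tensor([1, 4], dtype=torch.int32))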

View File

@ -166,8 +166,12 @@ def torch_decomp_decompositions(func):
from torch._decomp import decomposition_table
decompositions = torch._decomp.decompositions
decomp_attrs = [getattr(decompositions, attr) for attr in dir(decompositions)]
return decomposition_table[func] in decomp_attrs
# Note that the function in the decomposition table might be
# different from the one in the module because out= handling
# differs between the aten API and the torch public API
return decomposition_table[func].__module__.startswith(
"torch._decomp"
) and decomposition_table[func].__name__ in dir(decompositions)
def tree_flatten_only(ty: Type[T], tree: PyTree):

View File

@ -20146,6 +20146,10 @@ python_ref_db = [
# Fails on int32
# https://github.com/pytorch/pytorch/issues/85258
),
ElementwiseUnaryPythonRefInfo(
"_refs.frexp",
torch_opinfo_name="frexp",
),
ElementwiseUnaryPythonRefInfo(
"_refs.frac",
torch_opinfo_name="frac",