diff --git a/aten/src/ATen/native/cpu/MaxPooling.cpp b/aten/src/ATen/native/cpu/MaxPooling.cpp index 2a7a6784b7d7..06d0fe501426 100644 --- a/aten/src/ATen/native/cpu/MaxPooling.cpp +++ b/aten/src/ATen/native/cpu/MaxPooling.cpp @@ -32,13 +32,13 @@ void max_pool1d_impl( Tensor& output, const Tensor& input, const PoolingParams1D& p) { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] { + AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] { const Tensor in = input.contiguous(); scalar_t* const OP = output.data_ptr<scalar_t>(); const scalar_t* const IP = in.data_ptr<scalar_t>(); // Value used for padding - constexpr scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity + scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity() : std::numeric_limits<scalar_t>::lowest(); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a6bf6b49b495..fea34055eb50 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -330,12 +330,12 @@ - func: view_as_real(Tensor(a) self) -> Tensor(a) variants: function dispatch: - CPU, CUDA, MPS: view_as_real + CPU, CUDA, MPS, Meta: view_as_real - func: view_as_complex(Tensor(a) self) -> Tensor(a) variants: function dispatch: - CPU, CUDA: view_as_complex + CPU, CUDA, Meta: view_as_complex - func: sgn(Tensor self) -> Tensor variants: function, method diff --git a/docs/source/torch.overrides.rst b/docs/source/torch.overrides.rst index 0630b60c4b17..ce3583afa71e 100644 --- a/docs/source/torch.overrides.rst +++ b/docs/source/torch.overrides.rst @@ -14,6 +14,8 @@ Functions .. autofunction:: get_overridable_functions +.. autofunction:: resolve_name + .. autofunction:: get_testing_overrides ..
autofunction:: handle_torch_function diff --git a/test/test_meta.py b/test/test_meta.py index 50e8c9740e99..927c14aae7d8 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -1,7 +1,12 @@ # Owner(s): ["module: primTorch"] import torch +import os +from enum import Enum +from torch.overrides import resolve_name from torch.utils._pytree import tree_map, tree_flatten +import torch.utils._python_dispatch +from torch._prims.utils import is_complex_dtype, corresponding_real_dtype from torch.testing._internal.common_utils import ( TestCase, skipIfCrossRef, @@ -9,598 +14,750 @@ from torch.testing._internal.common_utils import ( TEST_WITH_ASAN, run_tests, ) -from torch.overrides import push_torch_function_mode from torch.testing._internal.common_device_type import ( - onlyCUDA, ops, instantiate_device_type_tests, + onlyCUDA, ) +from torch.testing._internal.logging_tensor import no_dispatch from torch.testing._internal.common_methods_invocations import op_db import torch._prims as prims -import functools +import atexit import re -from functools import partial +from collections import defaultdict import unittest import warnings +bf16 = torch.bfloat16 +f64 = torch.float64 +f32 = torch.float32 +f16 = torch.float16 +c32 = torch.complex32 +c64 = torch.complex64 +c128 = torch.complex128 +i8 = torch.int8 +i16 = torch.int16 +i32 = torch.int32 +i64 = torch.int64 +b8 = torch.bool +u8 = torch.uint8 + +dtype_abbrs = { + torch.bfloat16: 'bf16', + torch.float64: 'f64', + torch.float32: 'f32', + torch.float16: 'f16', + torch.complex32: 'c32', + torch.complex64: 'c64', + torch.complex128: 'c128', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + torch.bool: 'b8', + torch.uint8: 'u8', +} + +def safe_is_leaf(t): + try: + return t.is_leaf + except RuntimeError: + # inference mode can trigger this + return False + + +# This is a class for converting multiple tensors into meta tensors which +# share the same view/storage structure. The operation model is you allocate +# one of these, and then call it repeatedly on all the tensors you want to +# convert. It's important to use the same object for tensors you want to +# share storage because this is how we correlate shared storages to the same +# meta storages; similarly, it's important NOT to use the same object for +# unrelated groups of tensors because this class will remember all the +# tensors/storages its seen and therefore leak memory. +class MetaConverter: + def __init__(self): + self.storage_memo = {} + self.tensor_memo = {} + self.hit = 0 + self.miss = 0 + + def successful(self): + return self.hit > 0 and self.miss == 0 + + # NB: doesn't actually return a storage, because meta storage is + # not supported + def meta_storage(self, s): + # NB: TypedStorage is freshly allocated and cannot be used as hash + # key index. + if s._cdata not in self.storage_memo: + self.storage_memo[s._cdata] = torch.empty(s.size(), dtype=s.dtype, device='meta') + return self.storage_memo[s._cdata] + + # This function assumes that it's possible to do the conversion + def meta_tensor(self, t): + if t not in self.tensor_memo: + with torch.inference_mode(t.is_inference()): + if t._is_view(): + # Construct views in two steps: recursively meta-fy their + # base, and then create the view off that. NB: doing it + # directly from storage is WRONG because this won't cause + # version counters to get shared. 
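+                    # Concretely (mirroring TestMetaConverter.test_view_of_non_leaf
+                    # below): for z1 = y[:] and z2 = y[:], both conversions recurse
+                    # into self.meta_tensor(y), which is memoized in tensor_memo, so
+                    # the two meta views get the same _base and therefore share a
+                    # version counter, just like the real views do.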
+ assert t._is_view() + base = self.meta_tensor(t._base) + + def is_c_of_r(complex_dtype, real_dtype): + return is_complex_dtype(complex_dtype) and \ + corresponding_real_dtype(complex_dtype) == real_dtype + + if base.dtype == t.dtype: + pass + elif is_c_of_r(base.dtype, t.dtype): + base = torch.view_as_real(base) + elif is_c_of_r(t.dtype, base.dtype): + base = torch.view_as_complex(base) + else: + # This is not guaranteed to succeed. If it fails, it + # means there is another dtype-converting view function + # that hasn't been handled here + base = base.view(t.dtype) + + with torch.enable_grad(): + r = base.as_strided(t.size(), t.stride(), t.storage_offset()) + else: + is_leaf = safe_is_leaf(t) + # Fake up some autograd history. + if t.requires_grad: + r = torch.empty((0,), dtype=t.dtype, device='meta', requires_grad=True) + if not is_leaf: + with torch.enable_grad(): + # The backward function here will be wrong, but + # that's OK; our goal is just to get the metadata + # looking as close as possible; we're not going to + # actually try to backward() on these produced + # metas. TODO: would be safer to install some + # sort of unsupported grad_fn here + r = r.clone() + else: + r = torch.empty((0,), dtype=t.dtype, device='meta') + # As long as meta storage is not supported, need to prevent + # redispatching on set_(Storage, ...) which will choke with + # meta storage + s = self.meta_storage(t.storage()) + with no_dispatch(): + with torch.no_grad(): + r.set_(s, t.storage_offset(), t.size(), t.stride()) + + torch._C._set_conj(r, t.is_conj()) + torch._C._set_neg(r, t.is_neg()) + self.tensor_memo[t] = r + + return self.tensor_memo[t] + + def __call__(self, t): + # TODO: zero tensors? We appear to have eliminated them by + # excluding complex for now + if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: + if any([ + t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized, + t.is_nested, torch._is_functional_tensor(t), + # these are supported in meta conversion but the fallbacks + # don't work + t.is_neg(), t.is_conj(), + # conjugate fallback does not support meta tensors + t.dtype in (torch.complex128, torch.complex64), + ]): + # TODO: sparse should support meta + # NB technically to('meta') does work but our logging + # instrumentation will see the meta conversions and the + # tests all break so we just exclude this. In any case + # the to conversion isn't really right anyhow. + self.miss += 1 + return t + elif any([ + t.device.type in ("lazy", "meta"), t.is_complex(), + # We need a way to test if a tensor is batched but there + # is no official APi to do it + # torch._C._is_batched(t), + ]): + # TODO: this stuff should support storage + # (well, maybe not batched) + self.hit += 1 + return t.to("meta") + else: + self.hit += 1 + r = self.meta_tensor(t) + if type(t) is torch.nn.Parameter: + r = torch.nn.Parameter(r, requires_grad=r.requires_grad) + return r + elif torch.overrides.is_tensor_like(t): + # Blindly converting tensor subclasses to meta can cause + # unpredictable problems; e.g., FX tests will trace meta + # tensors into their trace / some subclasses don't correctly + # support meta. Trying to YOLO this is more trouble than it's + # worth. + self.miss += 1 + return t + else: + # non-Tensor types don't count as hit or miss + return t + + +class TestMetaConverter(TestCase): + def assertSameVersionCounter(self, m1, m2): + # Cannot easily test m1 and m2 have same storage due to + # lack of Storage bindings. Use version counter. 
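+        # Mutating m1._base in-place below must bump the version counter that
+        # both m1 and m2 observe; if the two meta views did not share a base,
+        # m2._version would be left unchanged and the asserts would fail.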
+ vc = m1._version + self.assertEqual(m2._version, vc) + # Doing it this way ensures that we get VC bump even with leaves + with torch.no_grad(): + m1._base.add_(3) + self.assertNotEqual(m1._version, vc) + self.assertEqual(m2._version, m1._version) + + def test_view_of_non_leaf(self): + x = torch.randn(4, requires_grad=True) + y = x.neg() + z1 = y[:] + z2 = y[:] + to_meta = MetaConverter() + m1 = to_meta(z1) + m2 = to_meta(z2) + self.assertEqual(m1.shape, z1.shape) + self.assertTrue(m1._is_view()) + self.assertFalse(m1._base.is_leaf) + self.assertSameVersionCounter(m1, m2) + + def test_view_of_leaf(self): + x = torch.randn(4, requires_grad=True) + z1 = x[:] + z2 = x[:] + to_meta = MetaConverter() + m1 = to_meta(z1) + m2 = to_meta(z2) + self.assertEqual(m1.shape, z1.shape) + self.assertTrue(m1._is_view()) + self.assertTrue(m1._base.is_leaf) + self.assertSameVersionCounter(m1, m2) + + def test_leaf(self): + x = torch.randn(4, requires_grad=True) + to_meta = MetaConverter() + m = to_meta(x) + self.assertEqual(m.shape, x.shape) + self.assertTrue(m.is_leaf) + self.assertTrue(m.requires_grad) + + def test_non_leaf(self): + x = torch.randn(4, requires_grad=True) + y = x.neg() + to_meta = MetaConverter() + m = to_meta(y) + self.assertEqual(m.shape, y.shape) + self.assertFalse(m.is_leaf) + self.assertTrue(m.requires_grad) + + def test_requires_grad_false(self): + x = torch.randn(4, requires_grad=False) + to_meta = MetaConverter() + m = to_meta(x) + self.assertEqual(m.shape, x.shape) + self.assertFalse(m.requires_grad) + + def test_view_as_real(self): + x = torch.randn(4, dtype=torch.complex64) + y = torch.view_as_real(x) + m = MetaConverter()(y) + self.assertEqual(m.shape, y.shape) + self.assertEqual(m.dtype, y.dtype) + + def test_view_as_complex(self): + x = torch.randn((4, 2), dtype=torch.float32) + y = torch.view_as_complex(x) + m = MetaConverter()(y) + self.assertEqual(m.shape, y.shape) + self.assertEqual(m.dtype, y.dtype) + + def test_view_dtype(self): + x = torch.randn(4, dtype=torch.float32) + y = x.view(dtype=torch.int32) + m = MetaConverter()(y) + self.assertEqual(m.shape, y.shape) + self.assertEqual(m.dtype, y.dtype) + + def test_imag(self): + x = torch.randn(4, dtype=torch.complex64) + y = x.imag + m = MetaConverter()(y) + self.assertEqual(m.shape, y.shape) + self.assertEqual(m.dtype, y.dtype) + self.assertEqual(m.stride(), y.stride()) + self.assertEqual(m.storage_offset(), y.storage_offset()) + + +def assert_ref_meta_equal(test_case, meta_rs, rs, msg_callable): + flat_meta_rs, _ = tree_flatten(meta_rs) + flat_rs, _ = tree_flatten(rs) + test_case.assertEqual(len(flat_meta_rs), len(flat_rs)) + for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs): + def test_assert(cond, msg): + if not cond: + raise RuntimeError(f"output {i}: {msg_callable(msg)}") + if not isinstance(r, torch.Tensor): + continue + test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor") + test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}") + test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}") + # NOTE: this helper is used instead of a direct stride comparison + # because strides of tensors with no elements and dimensions of + # length 1 are not computed consistently + same_strides, _ = prims.utils.check_significant_strides(meta_r, r) + test_assert(same_strides, f"but real stride was {r.stride()}") + test_assert( + meta_r.storage_offset() == r.storage_offset(), + f"but real storage_offset was {r.storage_offset()}") + test_assert(meta_r.requires_grad 
== r.requires_grad, f"but real requires_grad was {r.requires_grad}") + test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}") + test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}") + + +# This environment variable controls whether or not we print expected failure +# lists at the end of a test suite run. The intended usage looks like this: +# +# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_meta.py` on a CUDA build +# of PyTorch that has LAPACK/MAGMA installed. You can filter `-k test_meta` +# or `-k test_dispatch_meta` to only focus on one or another list +# 2. Given the printed skip/xfail list, add them to the corresponding lists; +# torch.* entries go in meta_function and aten.* entries go in meta_dispatch. +# If there are preexisting entries, you need to merge in the entries. +# +# This is somewhat manual but typically you shouldn't need to do this, unless +# you've made a major change (e.g., added a new dtype to PyTorch) and need to +# refresh the lists. If you want to do it from scratch, just clear out the +# preexisting lists before running. +# +# WARNING: Python dict literals will silently ignore duplicate keys +COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1' + +seen_succeeded = {} +seen_failed = {} +failed_reasons = defaultdict(set) +def print_seen(): + expected_failures = [] + skips = [] + + def fmt_dtypes(dtypes): + r = ', '.join(sorted(dtype_abbrs[d] for d in dtypes)) + return '{' + r + '}' + + for op, failed_dtypes in seen_failed.items(): + ops = resolve_name(op) + succeeded_dtypes = seen_succeeded.get(op, set()) + expected_failures_dtypes = failed_dtypes - succeeded_dtypes + skips_dtypes = failed_dtypes & succeeded_dtypes + reasons = "" + if failed_reasons[op]: + reasons = " # " + ", ".join(sorted(failed_reasons[op])) + if expected_failures_dtypes: + expected_failures.append(f" {ops}: {fmt_dtypes(expected_failures_dtypes)},{reasons}") + if skips_dtypes: + skips.append(f" {ops}: {fmt_dtypes(skips_dtypes)},") + expected_failures.sort() + skips.sort() + nl = '\n' + print(f"""\ +expected_failures = {{ +{nl.join(expected_failures)} +}} + +skips = {{ +{nl.join(skips)} +}} +""") +if COLLECT_EXPECT: + atexit.register(print_seen) + +# Success forces pass; failure forces fail; skip unconditionally skips testing +TestExpect = Enum("TestExpect", ("SUCCESS", "XFAILURE", "SKIP")) + +# unlike print produce strides +def verbose_print(e): + class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self): + return self.s + + def go(t): + if isinstance(t, torch.Tensor): + return Lit(f"{t} stride={t.stride()}") + else: + return t + + return repr(tree_map(go, e)) + +def run_meta_crossref( + test_case, + test_expect, + func, + args, + kwargs, + *, + dtype, + device_type, +): + to_meta = MetaConverter() + do_meta = test_expect is not TestExpect.SKIP + + if do_meta: + try: + meta_args = tree_map(to_meta, args) + meta_kwargs = tree_map(to_meta, kwargs) + except Exception as e: + raise RuntimeError( + f"failed to convert args to meta; " + f"originally (*{args}, **{kwargs})") from e + + rs = func(*args, **kwargs) + + # TODO: also handle cases where func raise an exception + + # For now, only attempt if we managed to convert all tensor types + # (if any of them failed, we're in a mixed device situation and + # this isn't well supported) + if do_meta and to_meta.successful(): + try: + # Suppress warnings, this doesn't matter for test_meta.py + # but it does matter if you want to use this decorator + # for cross-ref 
testing, as some tests may be looking at + # errors + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + meta_rs = func(*meta_args, **meta_kwargs) + except Exception as e: + if test_expect is TestExpect.XFAILURE: + return rs + seen_failed.setdefault(func, set()).add(dtype) + if isinstance(e, NotImplementedError): + m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0]) + if m: + failed_reasons[func].add(m.group(1)) + if COLLECT_EXPECT: + return rs + raise RuntimeError(f"""\ +failed to run: {resolve_name(func)}( +*{verbose_print(meta_args)}, +**{verbose_print(meta_kwargs)} +)""") from e + else: + try: + delim = ',\n ' + assert_ref_meta_equal(test_case, meta_rs, rs, lambda msg: f"""\ +meta disagrees with real impl: +{resolve_name(func)}( + {delim.join(map(verbose_print, meta_args))}, + {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())} +) = ( + {verbose_print(meta_rs)} +) +{msg} +""") + except Exception: + if test_expect is TestExpect.XFAILURE: + return rs + seen_failed.setdefault(func, set()).add(dtype) + if COLLECT_EXPECT: + return rs + raise + else: + seen_succeeded.setdefault(func, set()).add(dtype) + if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT: + raise RuntimeError(f"unexpected success {resolve_name(func)}") + + return rs + + + RE_NOT_IMPLEMENTED_MSG = re.compile(r"Could not run '([^']+)' with arguments ") -# These just need an implementation of meta tensors, once you -# implement them remove from this set. When doing comprehensive -# testing, we will verify that these raise errors when meta is run under -# OpInfo -meta_exclude_set = { - torch.Tensor.__lshift__, # MISSING aten::__lshift__.Scalar - torch.Tensor.__lshift__, # MISSING aten::__lshift__.Tensor - torch.Tensor.__rmatmul__, # MISSING aten::dot - torch.Tensor.__rshift__, # MISSING aten::__rshift__.Scalar - torch.Tensor.__rshift__, # MISSING aten::__rshift__.Tensor - torch.Tensor.addbmm, # MISSING aten::addbmm - torch.Tensor.addcmul, # MISSING aten::_local_scalar_dense - torch.Tensor.angle, # MISSING aten::angle - torch.Tensor.argsort, # MISSING aten::sort - torch.Tensor.bincount, # MISSING aten::bincount - torch.Tensor.cholesky, # MISSING aten::cholesky - torch.Tensor.cholesky_inverse, # MISSING aten::cholesky_inverse - torch.Tensor.cholesky_solve, # MISSING aten::_cholesky_solve_helper - torch.Tensor.clamp, # MISSING aten::clamp.Tensor - torch.Tensor.clamp_, # MISSING aten::clamp.Tensor_out - torch.Tensor.clip, # MISSING aten::clamp.Tensor - torch.Tensor.clip_, # MISSING aten::clamp.Tensor_out - torch.Tensor.conj_physical, # MISSING aten::conj_physical.out - torch.Tensor.corrcoef, # MISSING aten::_local_scalar_dense - torch.Tensor.count_nonzero, # MISSING aten::count_nonzero.dim_IntList - torch.Tensor.cov, # MISSING aten::_local_scalar_dense - torch.Tensor.cummax, # MISSING aten::_cummax_helper - torch.Tensor.cummin, # MISSING aten::_cummin_helper - torch.Tensor.cumprod_, # MISSING aten::logical_and.out - torch.Tensor.dequantize, # MISSING aten::dequantize.self - torch.Tensor.det, # MISSING aten::_det_lu_based_helper - torch.Tensor.diag, # MISSING aten::diag.out - torch.Tensor.diagflat, # MISSING aten::diag.out - torch.Tensor.dot, # MISSING aten::dot - torch.Tensor.eig, # MISSING aten::_local_scalar_dense - torch.Tensor.equal, # MISSING aten::equal - torch.Tensor.floor_divide, # MISSING aten::floor_divide - torch.Tensor.frexp, # MISSING aten::frexp.Tensor_out - torch.Tensor.geqrf, # MISSING aten::geqrf - torch.Tensor.histc, # MISSING aten::histc - torch.Tensor.histogram, # MISSING 
aten::histogram.bin_ct - torch.Tensor.inverse, # MISSING aten::_local_scalar_dense - torch.Tensor.is_set_to, # MISSING aten::is_set_to - torch.Tensor.istft, # MISSING aten::view_as_complex - torch.Tensor.kthvalue, # MISSING aten::kthvalue.values - torch.Tensor.logcumsumexp, # MISSING aten::_logcumsumexp - torch.Tensor.logdet, # MISSING aten::_local_scalar_dense - torch.Tensor.logical_and_, # MISSING aten::logical_and.out - torch.Tensor.logical_not, # MISSING aten::logical_not.out - torch.Tensor.logical_or_, # MISSING aten::logical_or.out - torch.Tensor.logical_xor, # MISSING aten::logical_xor.out - torch.Tensor.logical_xor_, # MISSING aten::logical_xor.out - torch.Tensor.logit, # MISSING aten::logit - torch.Tensor.logsumexp, # MISSING aten::abs - torch.Tensor.lstsq, # MISSING aten::lstsq - torch.Tensor.masked_select, # MISSING aten::masked_select - torch.Tensor.matmul, # MISSING aten::dot - torch.Tensor.matrix_exp, # MISSING aten::linalg_matrix_exp - torch.Tensor.matrix_power, # MISSING aten::eye.m_out - torch.Tensor.median, # MISSING aten::median - torch.Tensor.median, # MISSING aten::median.dim_values - torch.Tensor.mode, # MISSING aten::mode - torch.Tensor.msort, # MISSING aten::sort - torch.Tensor.multinomial, # MISSING aten::multinomial - torch.Tensor.mvlgamma, # MISSING aten::_local_scalar_dense - torch.Tensor.mvlgamma_, # MISSING aten::_local_scalar_dense - torch.Tensor.nan_to_num, # MISSING aten::nan_to_num.out - torch.Tensor.nan_to_num_, # MISSING aten::nan_to_num.out - torch.Tensor.nanmean, # MISSING aten::logical_not.out - torch.Tensor.nanmedian, # MISSING aten::nanmedian - torch.Tensor.nanmedian, # MISSING aten::nanmedian.dim_values - torch.Tensor.nanquantile, # MISSING aten::sort - torch.Tensor.nansum, # MISSING aten::nansum - torch.Tensor.narrow, # MISSING aten::_local_scalar_dense - torch.Tensor.nonzero, # MISSING aten::nonzero - torch.Tensor.orgqr, # MISSING aten::linalg_householder_product - torch.Tensor.ormqr, # MISSING aten::ormqr - torch.Tensor.prod, # MISSING aten::prod - torch.Tensor.qr, # MISSING aten::_linalg_qr_helper - torch.Tensor.quantile, # MISSING aten::sort - torch.Tensor.relu, # MISSING aten::relu - torch.Tensor.renorm_, # MISSING aten::_local_scalar_dense - torch.Tensor.repeat_interleave, # MISSING aten::repeat_interleave.Tensor - torch.Tensor.roll, # MISSING aten::roll - torch.Tensor.slogdet, # MISSING aten::linalg_slogdet - torch.Tensor.solve, # MISSING aten::_solve_helper - torch.Tensor.sort, # MISSING aten::sort - torch.Tensor.std, # MISSING aten::std.correction - torch.Tensor.stft, # MISSING aten::_fft_r2c - torch.Tensor.symeig, # MISSING aten::_symeig_helper - torch.Tensor.take, # MISSING aten::take - torch.Tensor.to_mkldnn, # MISSING aten::to_mkldnn - torch.Tensor.to_sparse, # MISSING aten::to_sparse - torch.Tensor.to_sparse_csr, # MISSING aten::to_sparse_csr - torch.Tensor.topk, # MISSING aten::_local_scalar_dense - torch.Tensor.trace, # MISSING aten::trace - torch.Tensor.unique, # MISSING aten::_unique2 - torch.Tensor.unique_consecutive, # MISSING aten::unique_consecutive - torch.Tensor.unsqueeze, # MISSING aten::_local_scalar_dense - torch.Tensor.var, # MISSING aten::var.correction - torch.Tensor.vdot, # MISSING aten::vdot - torch._add_relu, # MISSING aten::_add_relu.Tensor - torch._aminmax, # MISSING aten::_aminmax - torch._assert_async, # MISSING aten::_assert_async - torch._compute_linear_combination, # MISSING aten::_compute_linear_combination - torch._det_lu_based_helper, # MISSING aten::_det_lu_based_helper - torch._dirichlet_grad, # MISSING 
aten::_dirichlet_grad - torch._fake_quantize_learnable_per_channel_affine, # MISSING aten::_fake_quantize_learnable_per_channel_affine - torch._fake_quantize_learnable_per_tensor_affine, # MISSING aten::_fake_quantize_learnable_per_tensor_affine - torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams, # MISSING aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams # noqa: E501 - torch._foreach_abs, # MISSING aten::_foreach_abs - torch._foreach_abs_, # MISSING aten::_foreach_abs_ - torch._foreach_acos, # MISSING aten::_foreach_acos - torch._foreach_acos_, # MISSING aten::_foreach_acos_ - torch._foreach_add, # MISSING aten::_foreach_add.Scalar - torch._foreach_add_, # MISSING aten::_foreach_add_.Scalar - torch._foreach_addcdiv, # MISSING aten::_foreach_addcdiv.Scalar - torch._foreach_addcdiv_, # MISSING aten::_foreach_addcdiv_.Scalar - torch._foreach_addcmul, # MISSING aten::_foreach_addcmul.Scalar - torch._foreach_addcmul_, # MISSING aten::_foreach_addcmul_.Scalar - torch._foreach_asin, # MISSING aten::_foreach_asin - torch._foreach_asin_, # MISSING aten::_foreach_asin_ - torch._foreach_atan, # MISSING aten::_foreach_atan - torch._foreach_atan_, # MISSING aten::_foreach_atan_ - torch._foreach_ceil, # MISSING aten::_foreach_ceil - torch._foreach_ceil_, # MISSING aten::_foreach_ceil_ - torch._foreach_cos, # MISSING aten::_foreach_cos - torch._foreach_cos_, # MISSING aten::_foreach_cos_ - torch._foreach_cosh, # MISSING aten::_foreach_cosh - torch._foreach_cosh_, # MISSING aten::_foreach_cosh_ - torch._foreach_div, # MISSING aten::_foreach_div.Scalar - torch._foreach_div_, # MISSING aten::_foreach_div_.ScalarList - torch._foreach_erf, # MISSING aten::_foreach_erf - torch._foreach_erf_, # MISSING aten::_foreach_erf_ - torch._foreach_erfc, # MISSING aten::_foreach_erfc - torch._foreach_erfc_, # MISSING aten::_foreach_erfc_ - torch._foreach_exp, # MISSING aten::_foreach_exp - torch._foreach_exp_, # MISSING aten::_foreach_exp_ - torch._foreach_expm1, # MISSING aten::_foreach_expm1 - torch._foreach_expm1_, # MISSING aten::_foreach_expm1_ - torch._foreach_floor, # MISSING aten::_foreach_floor - torch._foreach_floor_, # MISSING aten::_foreach_floor_ - torch._foreach_frac, # MISSING aten::_foreach_frac - torch._foreach_frac_, # MISSING aten::_foreach_frac_ - torch._foreach_log, # MISSING aten::_foreach_log - torch._foreach_log10, # MISSING aten::_foreach_log10 - torch._foreach_log10_, # MISSING aten::_foreach_log10_ - torch._foreach_log1p, # MISSING aten::_foreach_log1p - torch._foreach_log1p_, # MISSING aten::_foreach_log1p_ - torch._foreach_log2, # MISSING aten::_foreach_log2 - torch._foreach_log2_, # MISSING aten::_foreach_log2_ - torch._foreach_log_, # MISSING aten::_foreach_log_ - torch._foreach_maximum, # MISSING aten::_foreach_maximum.List - torch._foreach_minimum, # MISSING aten::_foreach_minimum.List - torch._foreach_mul, # MISSING aten::_foreach_mul.Scalar - torch._foreach_mul_, # MISSING aten::_foreach_mul_.ScalarList - torch._foreach_neg, # MISSING aten::_foreach_neg - torch._foreach_neg_, # MISSING aten::_foreach_neg_ - torch._foreach_norm, # MISSING aten::_foreach_norm.Scalar - torch._foreach_reciprocal, # MISSING aten::_foreach_reciprocal - torch._foreach_reciprocal_, # MISSING aten::_foreach_reciprocal_ - torch._foreach_round, # MISSING aten::_foreach_round - torch._foreach_round_, # MISSING aten::_foreach_round_ - torch._foreach_sigmoid, # MISSING aten::_foreach_sigmoid - torch._foreach_sigmoid_, # MISSING aten::_foreach_sigmoid_ - torch._foreach_sin, # MISSING 
aten::_foreach_sin - torch._foreach_sin_, # MISSING aten::_foreach_sin_ - torch._foreach_sinh, # MISSING aten::_foreach_sinh - torch._foreach_sinh_, # MISSING aten::_foreach_sinh_ - torch._foreach_sqrt, # MISSING aten::_foreach_sqrt - torch._foreach_sqrt_, # MISSING aten::_foreach_sqrt_ - torch._foreach_sub, # MISSING aten::_foreach_sub.Scalar - torch._foreach_sub_, # MISSING aten::_foreach_sub_.ScalarList - torch._foreach_tan, # MISSING aten::_foreach_tan - torch._foreach_tan_, # MISSING aten::_foreach_tan_ - torch._foreach_tanh, # MISSING aten::_foreach_tanh - torch._foreach_tanh_, # MISSING aten::_foreach_tanh_ - torch._foreach_trunc, # MISSING aten::_foreach_trunc - torch._foreach_trunc_, # MISSING aten::_foreach_trunc_ - torch._foreach_zero_, # MISSING aten::_foreach_zero_ - torch._fused_moving_avg_obs_fq_helper, # MISSING aten::_fused_moving_avg_obs_fq_helper - torch._make_per_tensor_quantized_tensor, # MISSING aten::_make_per_tensor_quantized_tensor - torch._masked_softmax, # MISSING aten::_masked_softmax - torch._sample_dirichlet, # MISSING aten::_sample_dirichlet - torch._standard_gamma, # MISSING aten::_standard_gamma - torch._unique, # MISSING aten::_unique - torch._unique2, # MISSING aten::_unique2 - torch.addbmm, # MISSING aten::addbmm - torch.angle, # MISSING aten::angle - torch.batch_norm, # MISSING aten::native_batch_norm - torch.bernoulli, # MISSING aten::bernoulli.out - torch.bincount, # MISSING aten::bincount - torch.binomial, # MISSING aten::binomial - torch.bucketize, # MISSING aten::bucketize.Tensor - torch.cholesky, # MISSING aten::cholesky - torch.cholesky_inverse, # MISSING aten::cholesky_inverse - torch.cholesky_solve, # MISSING aten::_cholesky_solve_helper - torch.clip, # MISSING aten::clamp.Tensor - torch.combinations, # MISSING aten::masked_select - torch.complex, # MISSING aten::complex.out - torch.conj_physical, # MISSING aten::conj_physical.out - torch.corrcoef, # MISSING aten::_local_scalar_dense - torch.count_nonzero, # MISSING aten::count_nonzero.dim_IntList - torch.cov, # MISSING aten::_local_scalar_dense - torch.cummax, # MISSING aten::_cummax_helper - torch.cummin, # MISSING aten::_cummin_helper - torch.det, # MISSING aten::_det_lu_based_helper - torch.diag, # MISSING aten::diag.out - torch.diagflat, # MISSING aten::diag.out - torch.dot, # MISSING aten::dot - torch.eig, # MISSING aten::_local_scalar_dense - torch.equal, # MISSING aten::equal - torch.eye, # MISSING aten::eye.m_out - torch.fake_quantize_per_channel_affine, # MISSING aten::fake_quantize_per_channel_affine_cachemask - torch.fake_quantize_per_tensor_affine, # MISSING aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams - torch.fft.fft, # MISSING aten::_fft_r2c - torch.fft.fft2, # MISSING aten::_fft_c2c - torch.fft.fftn, # MISSING aten::_fft_c2c - torch.fft.fftshift, # MISSING aten::roll - torch.fft.hfft2, # MISSING aten::_fft_c2c - torch.fft.hfftn, # MISSING aten::_fft_c2c - torch.fft.ifft, # MISSING aten::_fft_r2c - torch.fft.ifft2, # MISSING aten::_fft_c2c - torch.fft.ifftn, # MISSING aten::_fft_c2c - torch.fft.ifftshift, # MISSING aten::roll - torch.fft.ihfft, # MISSING aten::_fft_r2c - torch.fft.ihfft2, # MISSING aten::_fft_r2c - torch.fft.ihfftn, # MISSING aten::_fft_r2c - torch.fft.irfft, # MISSING aten::_fft_c2r - torch.fft.irfft2, # MISSING aten::_fft_c2r - torch.fft.irfftn, # MISSING aten::_fft_c2r - torch.fft.rfft, # MISSING aten::_fft_r2c - torch.fft.rfft2, # MISSING aten::_fft_r2c - torch.fft.rfftn, # MISSING aten::_fft_r2c - torch.floor_divide, # MISSING 
aten::floor_divide - torch.frexp, # MISSING aten::frexp.Tensor_out - torch.functional.cdist, # MISSING aten::_cdist_forward - torch.functional.einsum, # MISSING aten::dot - torch.functional.istft, # MISSING aten::view_as_complex - torch.functional.pca_lowrank, # MISSING aten::_linalg_qr_helper - torch.functional.stft, # MISSING aten::_fft_r2c - torch.functional.svd_lowrank, # MISSING aten::_linalg_qr_helper - torch.functional.tensordot, # MISSING aten::tensordot.out - torch.functional.unique, # MISSING aten::_unique2 - torch.functional.unique_consecutive, # MISSING aten::unique_consecutive - torch.fused_moving_avg_obs_fake_quant, # MISSING aten::_fused_moving_avg_obs_fq_helper - torch.geqrf, # MISSING aten::geqrf - torch.group_norm, # MISSING aten::native_batch_norm - torch.histc, # MISSING aten::histc.out - torch.histogram, # MISSING aten::histogram.bin_ct - torch.histogramdd, # MISSING aten::_histogramdd_bin_edges - torch.inner, # MISSING aten::tensordot.out - torch.inverse, # MISSING aten::_local_scalar_dense - torch.kthvalue, # MISSING aten::kthvalue.values - torch.layer_norm, # MISSING aten::native_batch_norm - torch.linalg.cholesky, # MISSING aten::linalg_cholesky_ex - torch.linalg.cholesky_ex, # MISSING aten::linalg_cholesky_ex - torch.linalg.det, # MISSING aten::_det_lu_based_helper - torch.linalg.eig, # MISSING aten::linalg_eig - torch.linalg.eig, # MISSING aten::linalg_eig.out - torch.linalg.eigh, # MISSING aten::linalg_eigh - torch.linalg.eigvals, # MISSING aten::linalg_eig - torch.linalg.eigvalsh, # MISSING aten::linalg_eigh - torch.linalg.eigvalsh, # MISSING aten::linalg_eigvalsh.out - torch.linalg.householder_product, # MISSING aten::linalg_householder_product - torch.linalg.inv, # MISSING aten::_local_scalar_dense - torch.linalg.lstsq, # MISSING aten::linalg_lstsq.out - torch.linalg.lu_factor, # MISSING aten::_local_scalar_dense - torch.linalg.matmul, # MISSING aten::dot - torch.linalg.matrix_exp, # MISSING aten::linalg_matrix_exp - torch.linalg.matrix_power, # MISSING aten::_local_scalar_dense - torch.linalg.matrix_power, # MISSING aten::eye.m_out - torch.linalg.qr, # MISSING aten::_linalg_qr_helper - torch.linalg.slogdet, # MISSING aten::linalg_slogdet - torch.linalg.solve, # MISSING aten::linalg_solve - torch.linalg.solve_triangular, # MISSING aten::linalg_solve_triangular - torch.linalg.tensorinv, # MISSING aten::_local_scalar_dense - torch.linalg.tensorsolve, # MISSING aten::linalg_solve - torch.logcumsumexp, # MISSING aten::_logcumsumexp - torch.logdet, # MISSING aten::_local_scalar_dense - torch.logical_not, # MISSING aten::logical_not.out - torch.logical_xor, # MISSING aten::logical_xor.out - torch.logit, # MISSING aten::logit - torch.lstsq, # MISSING aten::lstsq - torch.lu_solve, # MISSING aten::lu_solve - torch.masked_select, # MISSING aten::masked_select - torch.matmul, # MISSING aten::dot - torch.matrix_exp, # MISSING aten::linalg_matrix_exp - torch.matrix_power, # MISSING aten::eye.m_out - torch.matrix_rank, # MISSING aten::linalg_eigvalsh.out - torch.median, # MISSING aten::median - torch.median, # MISSING aten::median.dim_values - torch.mode, # MISSING aten::mode - torch.multinomial, # MISSING aten::multinomial - torch.mvlgamma, # MISSING aten::_local_scalar_dense - torch.nan_to_num, # MISSING aten::nan_to_num.out - torch.nanmean, # MISSING aten::logical_not.out - torch.nanmedian, # MISSING aten::nanmedian - torch.nanmedian, # MISSING aten::nanmedian.dim_values - torch.nansum, # MISSING aten::nansum - torch.nn.functional.adaptive_avg_pool1d, # MISSING 
aten::_adaptive_avg_pool2d - torch.nn.functional.adaptive_avg_pool2d, # MISSING aten::_adaptive_avg_pool2d - torch.nn.functional.adaptive_avg_pool3d, # MISSING aten::_adaptive_avg_pool3d - torch.nn.functional.batch_norm, # MISSING aten::native_batch_norm - torch.nn.functional.binary_cross_entropy, # MISSING aten::binary_cross_entropy - torch.nn.functional.channel_shuffle, # MISSING aten::channel_shuffle - torch.nn.functional.cross_entropy, # MISSING aten::_local_scalar_dense - torch.nn.functional.cross_entropy, # MISSING aten::nll_loss2d_forward - torch.nn.functional.ctc_loss, # MISSING aten::_ctc_loss - torch.nn.functional.embedding_bag, # MISSING aten::_embedding_bag - torch.nn.functional.fold, # MISSING aten::col2im - torch.nn.functional.gaussian_nll_loss, # MISSING aten::_local_scalar_dense - torch.nn.functional.grid_sample, # MISSING aten::grid_sampler_2d - torch.nn.functional.group_norm, # MISSING aten::native_batch_norm - torch.nn.functional.hardswish, # MISSING aten::hardswish - torch.nn.functional.hardtanh, # MISSING aten::hardtanh - torch.nn.functional.instance_norm, # MISSING aten::native_batch_norm - torch.nn.functional.layer_norm, # MISSING aten::native_batch_norm - torch.nn.functional.logsigmoid, # MISSING aten::log_sigmoid_forward - torch.nn.functional.max_pool3d, # MISSING aten::max_pool3d_with_indices - torch.nn.functional.max_pool3d_with_indices, # MISSING aten::max_pool3d_with_indices - torch.nn.functional.max_unpool1d, # MISSING aten::max_unpool2d - torch.nn.functional.max_unpool2d, # MISSING aten::max_unpool2d - torch.nn.functional.max_unpool3d, # MISSING aten::max_unpool3d - torch.nn.functional.multi_head_attention_forward, # MISSING aten::logical_or.out - torch.nn.functional.multi_margin_loss, # MISSING aten::multi_margin_loss - torch.nn.functional.multilabel_margin_loss, # MISSING aten::multilabel_margin_loss_forward - torch.nn.functional.multilabel_soft_margin_loss, # MISSING aten::log_sigmoid_forward - torch.nn.functional.nll_loss, # MISSING aten::nll_loss2d_forward - torch.nn.functional.one_hot, # MISSING aten::_local_scalar_dense - torch.nn.functional.pdist, # MISSING aten::_pdist_forward - torch.nn.functional.prelu, # MISSING aten::prelu - torch.nn.functional.relu, # MISSING aten::relu - torch.nn.functional.relu6, # MISSING aten::hardtanh - torch.nn.functional.rrelu, # MISSING aten::rrelu_with_noise - torch.nn.functional.unfold, # MISSING aten::im2col - torch.nonzero, # MISSING aten::nonzero - torch.normal, # MISSING aten::_local_scalar_dense - torch.orgqr, # MISSING aten::linalg_householder_product - torch.ormqr, # MISSING aten::ormqr - torch.poisson, # MISSING aten::poisson - torch.polar, # MISSING aten::polar.out - torch.prod, # MISSING aten::prod - torch.qr, # MISSING aten::_linalg_qr_helper - torch.quantize_per_channel, # MISSING aten::quantize_per_channel - torch.quantize_per_tensor, # MISSING aten::quantize_per_tensor - torch.quantize_per_tensor_dynamic, # MISSING aten::quantize_per_tensor_dynamic - torch.relu, # MISSING aten::relu - torch.repeat_interleave, # MISSING aten::repeat_interleave.Tensor - torch.rnn_relu, # MISSING aten::relu - torch.rnn_relu_cell, # MISSING aten::relu - torch.roll, # MISSING aten::roll - torch.rsub, # MISSING aten::rsub.Tensor - torch.searchsorted, # MISSING aten::searchsorted.Tensor - torch.slogdet, # MISSING aten::linalg_slogdet - torch.solve, # MISSING aten::_solve_helper - torch.special.logit, # MISSING aten::logit - torch.special.logsumexp, # MISSING aten::abs.out - torch.special.multigammaln, # MISSING 
aten::_local_scalar_dense - torch.square, # MISSING aten::square.out - torch.std, # MISSING aten::std.correction - torch.std_mean, # MISSING aten::std_mean.correction - torch.symeig, # MISSING aten::_symeig_helper - torch.take, # MISSING aten::take - torch.threshold, # MISSING aten::_local_scalar_dense - torch.trace, # MISSING aten::trace - torch.var, # MISSING aten::var.correction - torch.var_mean, # MISSING aten::var_mean.correction - torch.vdot, # MISSING aten::vdot - torch.nanquantile, # MISSING aten::logical_not.out + +meta_function_expected_failures = { + torch.Tensor.item: {b8, bf16, c128, c64, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense + torch.Tensor.to_sparse: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::to_sparse, aten::to_sparse.sparse_dim + torch.addbmm: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::addbmm, aten::addbmm.out + torch.allclose: {bf16, f16, f32, f64}, # aten::_local_scalar_dense + torch.angle: {c32, b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::angle, aten::angle.out + torch.argwhere: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::nonzero + torch.bincount: {i16, i32, i64, i8, u8}, # aten::bincount + torch.bucketize: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::bucketize.Tensor, aten::bucketize.Tensor_out + torch.combinations: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::masked_select + torch.complex: {f16, f32, f64}, # aten::complex.out + torch.conj_physical: {c32}, # aten::conj_physical.out + torch.corrcoef: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense + torch.count_nonzero: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::count_nonzero.dim_IntList + torch.cov: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense + torch.diag: {bf16, b8, f32, f64, i16, i32, i64, i8, u8}, # aten::diag.out + torch.diagflat: {bf16, b8, f32, f64, i16, i32, i64, i8, u8}, # aten::diag.out + torch.dot: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::dot + torch.fft.fft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c + torch.fft.fft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c + torch.fft.fftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c + torch.fft.fftshift: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::roll + torch.fft.hfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c + torch.fft.hfft: {b8, f32, f64, i16, i32, i64, i8, u8}, + torch.fft.hfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c + torch.fft.ifft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c + torch.fft.ifft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c + torch.fft.ifftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c + torch.fft.ifftshift: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::roll + torch.fft.ihfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c + torch.fft.ihfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c + torch.fft.ihfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c + torch.fft.irfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out + torch.fft.irfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out + torch.fft.irfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out + torch.fft.rfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c + torch.fft.rfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c + torch.fft.rfftn: {b8, f32, f64, i16, i32, 
i64, i8, u8}, # aten::_fft_r2c + torch.floor_divide: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::floor_divide, aten::floor_divide.out + torch.frexp: {bf16, f16, f32, f64}, # aten::frexp.Tensor_out + torch.functional.istft: {f32, f64}, # aten::view_as_complex + torch.functional.stft: {f32, f64}, # aten::_fft_r2c + torch.functional.unique: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_unique2, aten::unique_dim + torch.functional.unique_consecutive: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::unique_consecutive + torch.histc: {bf16, f32, f64}, # aten::histc, aten::histc.out + torch.histogram: {f32, f64}, # aten::histogram.bin_ct, aten::histogram.bins_tensor + torch.histogramdd: {f32, f64}, # aten::_histogramdd_bin_edges, aten::_histogramdd_from_bin_tensors + torch.kthvalue: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::kthvalue.values + torch.linalg.qr: {f32, f64}, # aten::_linalg_qr_helper + torch.logcumsumexp: {bf16, f32, f64}, # aten::_logcumsumexp, aten::_logcumsumexp.out + torch.masked_select: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::masked_select, aten::masked_select.out + torch.matrix_exp: {bf16, f32, f64}, # aten::linalg_matrix_exp + torch.median: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::median, aten::median.dim_values + torch.mode: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::mode + torch.multinomial: {bf16, f32, f64}, # aten::multinomial, aten::multinomial.out + torch.mvlgamma: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense, aten::mvlgamma.out + torch.nan_to_num: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::nan_to_num.out + torch.nanmean: {bf16, f16, f32, f64}, + torch.nanmedian: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::nanmedian, aten::nanmedian.dim_values + torch.nanquantile: {f32, f64}, + torch.nansum: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::nansum, aten::nansum.out + torch.nn.functional.adaptive_avg_pool2d: {bf16, f32, f64}, # aten::_adaptive_avg_pool2d + torch.nn.functional.conv1d: {bf16, f32, f64, i64}, + torch.nn.functional.conv2d: {bf16, f32, f64, i64}, + torch.nn.functional.conv_transpose1d: {f32, f64, i64}, + torch.nn.functional.conv_transpose2d: {f32, f64, i64}, + torch.nn.functional.conv_transpose3d: {f32, f64, i64}, + torch.nn.functional.ctc_loss: {f32, f64}, + torch.nn.functional.embedding_bag: {f16, f32, f64}, # aten::_embedding_bag_forward_only + torch.nn.functional.gaussian_nll_loss: {bf16, f32, f64}, # aten::_local_scalar_dense + torch.nn.functional.grid_sample: {f32, f64}, # aten::grid_sampler_2d, aten::grid_sampler_3d + torch.nn.functional.group_norm: {bf16, f32, f64}, # aten::var_mean.correction + torch.nn.functional.instance_norm: {f32, f64}, # aten::var_mean.correction + torch.nn.functional.layer_norm: {bf16, f32, f64}, + torch.nn.functional.max_pool3d: {f32, f64}, # aten::max_pool3d_with_indices + torch.nn.functional.max_pool3d_with_indices: {f32, f64}, # aten::max_pool3d_with_indices + torch.nn.functional.max_unpool1d: {f32, f64}, # aten::max_unpool2d + torch.nn.functional.max_unpool2d: {f32, f64}, # aten::max_unpool2d + torch.nn.functional.max_unpool3d: {f32, f64}, # aten::max_unpool3d + torch.nn.functional.multi_margin_loss: {f32, f64}, # aten::multi_margin_loss + torch.nn.functional.multilabel_margin_loss: {f32, f64}, # aten::multilabel_margin_loss_forward + torch.nn.functional.one_hot: {i64}, # aten::_local_scalar_dense + torch.nn.functional.pdist: {f32, f64}, # aten::_pdist_forward + torch.nn.functional.prelu: {bf16, f32, 
f64}, # aten::prelu + torch.nn.functional.relu: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::relu + torch.nn.functional.rrelu: {bf16, f32, f64}, # aten::rrelu_with_noise + torch.nn.functional.unfold: {bf16, f16, f32, f64}, # aten::im2col + torch.nonzero: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::nonzero, aten::nonzero.out + torch.polar: {f32, f64}, # aten::polar.out + torch.repeat_interleave: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::repeat_interleave.Tensor + torch.roll: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::roll + torch.searchsorted: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::searchsorted.Tensor, aten::searchsorted.Tensor_out + torch.symeig: {f32, f64}, + torch.std_mean: {bf16, f16, f32, f64}, # aten::std_mean.correction + torch.take: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::take, aten::take.out + torch.trace: {f32, f64, i16, i32, i64, i8, u8}, # aten::trace + torch.var_mean: {bf16, f16, f32, f64}, # aten::var_mean.correction + torch.vdot: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::vdot + torch.qr: {f32, f64}, + torch.ormqr: {f32, f64}, + torch.lu_solve: {f32, f64}, + torch.cholesky: {f32, f64}, # aten::cholesky, aten::cholesky.out + torch.cholesky_inverse: {f32, f64}, # aten::cholesky_inverse, aten::cholesky_inverse.out + torch.cholesky_solve: {f32, f64}, # aten::_cholesky_solve_helper + torch.eig: {f32, f64}, # aten::_local_scalar_dense + torch.geqrf: {f32, f64}, # aten::geqrf + torch.linalg.cholesky: {f32, f64}, # aten::linalg_cholesky_ex, aten::linalg_cholesky_ex.L + torch.linalg.cholesky_ex: {f32, f64}, # aten::linalg_cholesky_ex + torch.linalg.det: {f32, f64}, # aten::_det_lu_based_helper + torch.linalg.eig: {f32, f64}, # aten::linalg_eig + torch.linalg.eigh: {f32, f64}, + torch.linalg.eigvals: {f32, f64}, + torch.linalg.eigvalsh: {f32, f64}, # aten::linalg_eigvalsh.out + torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product + torch.linalg.inv: {f32, f64}, # aten::_local_scalar_dense + torch.linalg.ldl_factor: {f32, f64}, # aten::_local_scalar_dense + torch.linalg.lstsq: {f32, f64}, # aten::linalg_lstsq.out + torch.linalg.lu_factor: {f32, f64}, # aten::_local_scalar_dense + torch.linalg.slogdet: {f32, f64}, # aten::linalg_slogdet + torch.linalg.solve: {f32, f64}, # aten::linalg_solve, aten::linalg_solve.out + torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular + torch.linalg.tensorinv: {f32, f64}, # aten::_local_scalar_dense + torch.linalg.tensorsolve: {f32, f64}, # aten::linalg_solve + torch.logdet: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero } -# Only some overloads/configurations are covered with meta tensors, -# so we can't use these to toggle expected failure. 
Try to prioritize these -overload_exclude_set = { - torch.clamp, # MISSING aten::clamp.Tensor - torch.nn.functional.interpolate, # MISSING aten::upsample_nearest3d.vec - torch.nn.functional.upsample_nearest, # MISSING aten::upsample_nearest3d.vec - torch.nn.functional.pad, # MISSING aten::reflection_pad2d - torch.remainder, # MISSING aten::remainder.Scalar_Tensor - torch.linalg.matrix_rank, # MISSING aten::linalg_eigh - torch.diff, # MISSING aten::logical_xor.out - torch.linalg.pinv, # CompositeExplicitAutograd but mH fails +""" +# This is some sample code for how we could dump these dicts into YAML +# file for easier reading/writing +import yaml +print(yaml.dump( + {resolve_name(k): [dtype_abbrs[d] for d in v] + for k, v in meta_function_expected_failures.items()}, default_flow_style=None)) +import sys +sys.exit() +""" + +meta_function_skips = { + torch.Tensor.__getitem__: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8, c32}, + torch.Tensor.__rmatmul__: {bf16, f32, f64, i16, i32, i64, i8, u8}, + torch.index_reduce: {bf16, f16, f32, f64}, + torch.addr: {b8}, + torch.aminmax: {b8, f32, f64, i16, i32, i64, i8, u8}, + torch.bernoulli: {bf16, f32, f64}, + torch.conj_physical: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, + torch.cummax: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, + torch.cummin: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, + torch.diff: {b8}, + torch.functional.cdist: {f32, f64}, + torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8}, + torch.index_add: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, + torch.inner: {bf16, f32, f64, i16, i32, i64, i8, u8}, + torch.logical_not: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, + torch.logical_xor: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, + torch.logit: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, + torch.matmul: {bf16, f32, f64, i16, i32, i64, i8, u8}, + torch.nn.functional.adaptive_avg_pool1d: {bf16, f32, f64}, + torch.nn.functional.adaptive_avg_pool3d: {f16, f32, f64}, + torch.nn.functional.batch_norm: {f32, f64}, + torch.nn.functional.cross_entropy: {bf16, f32, f64}, + torch.nn.functional.interpolate: {bf16, f32, f64, u8}, + torch.nn.functional.nll_loss: {bf16, f32, f64}, + torch.nn.functional.pad: {f32, f64}, + torch.normal: {bf16, f16, f32, f64}, + torch.prod: {b8, f32, f64, i16, i32, i64, i8, u8}, + torch.square: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, + torch.tensor_split: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, + torch.nn.functional.logsigmoid: {bf16, f16, f32, f64}, # logsigmoid.output + torch.inverse: {f32, f64}, + torch.linalg.matrix_power: {f32, f64}, + torch.linalg.matrix_rank: {f32, f64}, + torch.linalg.pinv: {f32, f64}, + torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, } -# These are fine in OpInfo tests, but triggered errors in full test suite -# crossref testing, which means there is probably not enough coverage from -# OpInfo. Patch in https://github.com/pytorch/pytorch/pull/75994 and find -# out where these fails come from. 
-suspicious_exclude_set = { - torch.add, # MISSING aten::_local_scalar_dense - torch.cat, # MISSING aten::_local_scalar_dense - torch.cumprod, # MISSING aten::logical_and.out - torch.cumsum, # MISSING aten::_local_scalar_dense - torch.functional.norm, # MISSING aten::isnan +meta_function_device_expected_failures = defaultdict(dict) +meta_function_device_skips = defaultdict(dict) - # RuntimeError: Expected 3D or 4D (batch mode) tensor with optional 0 dim - # batch size for input, but got:[1, 1, 0] - # in test_nn.py TestNNDeviceTypeCPU.test_max_pool1d_corner_cases_cpu_float64 - torch.nn.functional.max_pool1d, - - # Factory functions need tricky kwarg handling - torch.zeros_like, +meta_function_device_expected_failures['cpu'] = { } -# These also are known to not work, but they fail in a more special way -# than the regular "Meta not implemented for aten op" way -meta_exclude_set |= { - # Convolutions have a special error message - torch.nn.functional.conv1d, - torch.nn.functional.conv2d, - torch.nn.functional.conv3d, - torch.nn.functional.conv_transpose1d, - torch.nn.functional.conv_transpose2d, - torch.nn.functional.conv_transpose3d, - # complex stuff handle it specially - torch.view_as_complex, - torch.view_as_real, - # These operators happen very frequently, although they should - # work with meta we intentionally don't test them to speed - # up the test suite - torch.Tensor.__getitem__, - torch.Tensor.__rsub__, - torch.Tensor.__setitem__, - torch.Tensor.add, - torch.Tensor.add_, - torch.Tensor.clone, - torch.Tensor.detach, - torch.Tensor.div, - torch.Tensor.mul, - torch.Tensor.reshape, - torch.Tensor.sub, - torch.Tensor.sum, - torch.rand, - # These correctly report NotImplemented but they don't print - # correctly from resolve_name - torch.ops.quantized.linear_dynamic, - torch._VF.unique_dim, - torch._C._nn.binary_cross_entropy, - torch._C._nn.adaptive_avg_pool2d, - torch._C._nn._test_optional_filled_intlist, - torch._C._nn._test_optional_floatlist, - torch._C._nn._test_optional_intlist, - # Meta tensors don't support storage Python bindings at the - # moment, to be fixed - torch.Tensor.storage, - torch.Tensor.storage_type, - torch.Tensor.share_memory_, - # Weird stuff that hypothetically should work but it's weird - torch._make_dual, - torch._unpack_dual, # fails because we don't preserve forward ad tangent in test code - # These functions cannot, even in principle, be implemented on meta - # tensors (because they involve accessing data somehow), so don't test - # them. 
- torch.Tensor.__bool__, - torch.Tensor.__float__, - torch.Tensor.__int__, - torch.Tensor.__complex__, - torch.Tensor.__index__, - torch.Tensor.__contains__, - torch.Tensor.cpu, - torch.Tensor.to, - torch.Tensor.tolist, - torch.Tensor.unbind, - torch.Tensor.item, - torch.Tensor.is_nonzero, - torch.Tensor.copy_, - torch.Tensor.numpy, - torch.Tensor.allclose, - torch.Tensor.argwhere, - torch.allclose, - torch.argwhere, - torch.tensor_split, - torch.Tensor.tensor_split, - torch.Tensor.__array__, # doesn't raise NotImplementedError - torch.Tensor.__dlpack_device__, # doesn't raise NotImplementedError - torch.Tensor.__dlpack__, # doesn't raise NotImplementedError - torch.to_dlpack, # doesn't raise NotImplementedError - # Utility functions that get frequently invoked; don't test - torch.Tensor.__format__, - torch.Tensor.__repr__, - # These are getters/setters for properties on tensors; it's not - # really useful to test meta tensors on them - torch.Tensor.device.__get__, - torch.Tensor.dtype.__get__, - torch.Tensor.grad.__get__, - torch.Tensor.grad.__set__, - torch.Tensor.is_sparse.__get__, - torch.Tensor.layout.__get__, - torch.Tensor.shape.__get__, - torch.Tensor.requires_grad.__get__, - torch.Tensor.requires_grad.__set__, - torch.Tensor.data.__get__, - torch.Tensor.data.__set__, - torch.Tensor._base.__get__, - torch.Tensor.is_shared, - torch.Tensor.imag.__get__, - torch.Tensor.real.__get__, - torch.Tensor.__setstate__, - torch.Tensor.is_complex, - torch.Tensor.is_floating_point, - torch.Tensor.numel, - torch.Tensor.requires_grad_, - torch.Tensor.size, - # These perturb RNG and can cause tests to fail, so don't run - # them (TODO: this is not a complete list) - torch.randint, - torch.randn, - # Indirect use of conjugate fallback - torch.fft.hfft, - # These don't raise NotImplementedError, which suggests something - # is wrong with how they're registered with the dispatcher - torch.fbgemm_pack_gemm_matrix_fp16, - torch.fbgemm_pack_quantized_matrix, - torch.fbgemm_linear_fp16_weight, - torch._empty_per_channel_affine_quantized, - torch.fbgemm_linear_int8_weight, - torch._grid_sampler_2d_cpu_fallback, # WAT - torch._nnpack_spatial_convolution, - torch.lstm, - torch.Tensor.conj_physical_, - torch.rnn_tanh, - torch.fbgemm_linear_quantize_weight, - torch._reshape_from_tensor, - torch.gru, - torch.Tensor.unflatten, - torch._saturate_weight_to_fp16, - torch.choose_qparams_optimized, - torch._validate_sparse_coo_tensor_args, - torch.sparse.mm, - torch.Tensor.new, - torch.Tensor.resize, # WTF is this - torch._sobol_engine_initialize_state_, - torch._sobol_engine_draw, - torch._sobol_engine_scramble_, - torch._sobol_engine_ff_, - torch._pack_padded_sequence, - torch._pad_packed_sequence, - torch.sparse_coo_tensor, - torch.linalg.ldl_factor, - torch.index_reduce, - # IndexError: select() cannot be applied to a 0-dim tensor. - # e.g. test_fn_fwgrad_bwgrad_index_add_cpu_complex128 (__main__.TestGradientsCPU) - torch.index_add, - torch.Tensor.index_add, - torch.Tensor.index_add_, - # Can't copy out of meta tensor - torch.linalg.eigvals, - torch.linalg.lu_factor, - torch.nn.functional.ctc_loss, - # Our conversion to meta is not accurate enough (doesn't - # preserve storage_offset, e.g.) 
- torch.Tensor.as_strided, - # This one segfaults when you call it - torch.Tensor.type, - # We don't clone autograd history, so this will generally not work - torch.autograd.grad, - torch.Tensor.backward, - torch.Tensor.__deepcopy__, - # Don't do factories - torch.ones, - torch.full, - torch.empty, - torch.randperm, - torch.logspace, - torch.zeros, - torch.arange, - torch.vander, - torch.as_tensor, - torch.tensor, - torch.randn_like, - torch.sparse_csr_tensor, - torch._sparse_coo_tensor_unsafe, - torch._sparse_csr_tensor_unsafe, - torch._validate_sparse_csr_tensor_args, +meta_function_device_expected_failures['cuda'] = { + torch.addbmm: {f16}, # aten::addbmm, aten::addbmm.out + torch.corrcoef: {bf16, f16}, # aten::_local_scalar_dense + torch.cov: {f16}, # aten::_local_scalar_dense + torch.diag: {bf16, f16}, # aten::diag.out + torch.diagflat: {bf16, f16}, # aten::diag.out + torch.dot: {f16}, # aten::dot + torch.fft.fft2: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out + torch.fft.fft: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out + torch.fft.fftn: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out + torch.fft.hfft2: {c32, f16}, # aten::_fft_c2c + torch.fft.hfft: {c32, f16}, + torch.fft.hfftn: {c32, f16}, # aten::_fft_c2c + torch.fft.ifft2: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out + torch.fft.ifft: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out + torch.fft.ifftn: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out + torch.fft.ihfft2: {f16}, + torch.fft.ihfft: {f16}, + torch.fft.ihfftn: {f16}, + torch.fft.irfft2: {c32, f16}, # aten::_fft_c2r, aten::_fft_c2r.out + torch.fft.irfft: {c32, f16}, # aten::_fft_c2r, aten::_fft_c2r.out + torch.fft.irfftn: {c32, f16}, # aten::_fft_c2r, aten::_fft_c2r.out + torch.fft.rfft2: {f16}, + torch.fft.rfft: {f16}, + torch.fft.rfftn: {f16}, + torch.functional.unique: {f16}, # aten::_unique2, aten::unique_dim + torch.functional.unique_consecutive: {f16}, # aten::unique_consecutive + torch.geqrf: {f32, f64}, # aten::geqrf + torch.histc: {i16, i32, i64, i8}, # aten::histc, aten::histc.out + torch.kthvalue: {f16}, # aten::kthvalue.values + torch.linalg.cholesky: {f32, f64}, # aten::linalg_cholesky_ex, aten::linalg_cholesky_ex.L + torch.linalg.cholesky_ex: {f32, f64}, # aten::linalg_cholesky_ex + torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product, aten::linalg_householder_product.out + torch.linalg.inv: {f32, f64}, # aten::_local_scalar_dense + torch.linalg.ldl_factor: {f32, f64}, # aten::_local_scalar_dense + torch.linalg.lu_factor: {f32, f64}, # aten::_local_scalar_dense + torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out + torch.linalg.tensorinv: {f32, f64}, # aten::_local_scalar_dense + torch.logcumsumexp: {bf16, f16}, # aten::_logcumsumexp, aten::_logcumsumexp.out + torch.matrix_exp: {f16}, # aten::linalg_matrix_exp + torch.median: {f16}, # aten::median, aten::median.dim_values + torch.multinomial: {f16}, # aten::multinomial, aten::multinomial.out + torch.mvlgamma: {f16}, # aten::_local_scalar_dense, aten::mvlgamma.out + torch.nanmedian: {f16}, # aten::nanmedian, aten::nanmedian.dim_values + torch.nn.functional.adaptive_avg_pool2d: {f16}, # aten::_adaptive_avg_pool2d + torch.nn.functional.conv1d: {f16}, + torch.nn.functional.conv2d: {f16}, + torch.nn.functional.conv_transpose1d: {bf16, f16}, + torch.nn.functional.conv_transpose2d: {bf16, f16}, + torch.nn.functional.conv_transpose3d: {bf16, f16}, + torch.nn.functional.embedding_bag: {bf16}, # 
aten::_embedding_bag_forward_only + torch.nn.functional.gaussian_nll_loss: {f16}, # aten::_local_scalar_dense + torch.nn.functional.grid_sample: {f16}, # aten::grid_sampler_2d, aten::grid_sampler_3d + torch.nn.functional.group_norm: {bf16, f16}, # aten::var_mean.correction + torch.nn.functional.instance_norm: {bf16, f16}, # aten::var_mean.correction + torch.nn.functional.layer_norm: {f16}, + torch.nn.functional.max_pool3d: {bf16, f16}, # aten::max_pool3d_with_indices + torch.nn.functional.max_pool3d_with_indices: {bf16, f16}, # aten::max_pool3d_with_indices + torch.nn.functional.max_unpool1d: {f16}, # aten::max_unpool2d + torch.nn.functional.max_unpool2d: {f16}, # aten::max_unpool2d + torch.nn.functional.max_unpool3d: {f16}, # aten::max_unpool3d + torch.nn.functional.multi_margin_loss: {bf16, f16}, # aten::multi_margin_loss + torch.nn.functional.multilabel_margin_loss: {bf16, f16}, # aten::multilabel_margin_loss_forward + torch.nn.functional.prelu: {f16}, # aten::prelu + torch.nn.functional.relu: {f16}, # aten::relu + torch.nn.functional.rrelu: {f16}, # aten::rrelu_with_noise + torch.ormqr: {f32, f64}, # aten::ormqr, aten::ormqr.out + torch.qr: {f32, f64}, # aten::_linalg_qr_helper + torch.trace: {b8, bf16, f16}, # aten::diag.out + torch.vdot: {f16}, # aten::vdot +} + +meta_function_device_skips['cuda'] = { + torch.Tensor.__getitem__: {c32}, + torch.Tensor.__rmatmul__: {f16}, + torch.bernoulli: {f16}, + torch.cummax: {f16}, + torch.cummin: {f16}, + torch.functional.tensordot: {f16}, + torch.inner: {f16}, + torch.inverse: {f32, f64}, + torch.linalg.matrix_power: {f32, f64}, + torch.linalg.matrix_rank: {f32, f64}, + torch.linalg.svd: {f32, f64}, + torch.logit: {f16}, + torch.matmul: {f16}, + torch.nn.functional.adaptive_avg_pool1d: {f16}, + torch.nn.functional.adaptive_avg_pool3d: {bf16}, + torch.nn.functional.batch_norm: {bf16, f16}, + torch.nn.functional.cross_entropy: {f16}, + torch.nn.functional.interpolate: {f16}, + torch.nn.functional.nll_loss: {f16}, + torch.nn.functional.pad: {f16}, + torch.prod: {bf16, c32, f16}, + torch.svd: {f32, f64}, } # This is a __torch_function__ mode that, when enabled, interposes every @@ -614,187 +771,347 @@ meta_exclude_set |= { # this test file (since I could have just inlined __torch_function__ on the # OpInfo call, and OpInfos generally have very regular inputs), but it will be # useful for more comprehensive testing e.g., as seen in -# https://github.com/pytorch/pytorch/pull/75994 -class MetaCrossRefMode(torch.overrides.TorchFunctionMode): +# https://github.com/pytorch/pytorch/pull/75994 The big benefit is it is +# A LOT more efficient that torch dispatch mode (at the cost of less coverage) +class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode): test_case: TestCase - run_excludes_anyway: bool + device_type: str + dtype: torch.dtype - def __init__(self, test_case, *, run_excludes_anyway): + def __init__(self, test_case, *, device, dtype): self.test_case = test_case - self.run_excludes_anyway = run_excludes_anyway + self.device_type = torch.device(device).type + self.dtype = dtype def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - hit = 0 - miss = 0 + if torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod): + return func(*args, **kwargs) - # Doesn't actually return a storage - @functools.lru_cache(None) - def meta_storage(s): - return torch.empty(s.size(), dtype=s.dtype, device='meta') + if self.dtype in meta_function_skips.get(func, set()): + test_expect = TestExpect.SKIP + elif self.dtype in 
meta_function_device_skips[self.device_type].get(func, set()): + test_expect = TestExpect.SKIP + elif self.dtype in meta_function_expected_failures.get(func, set()): + test_expect = TestExpect.XFAILURE + elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()): + test_expect = TestExpect.XFAILURE + else: + test_expect = TestExpect.SUCCESS - def safe_is_leaf(t): - try: - return t.is_leaf - except RuntimeError: - # inference mode can trigger this - return False - - @functools.lru_cache(None) - def meta_tensor(t): - with torch.inference_mode(t.is_inference()): - s = meta_storage(t.storage()) - is_leaf = safe_is_leaf(t) - if is_leaf or not t._is_view(): - r = torch.empty( - (0,), dtype=t.dtype, device='meta' - ) - r.set_(s, t.storage_offset(), t.size(), t.stride()) - r.requires_grad = t.requires_grad - if not is_leaf and t.requires_grad: - with torch.enable_grad(): - r = r.clone() - else: - base = torch.empty( - (0,), dtype=t.dtype, device='meta' - ) - base.set_(s, 0, s.size(), (1,)) - base.requires_grad = t.requires_grad - with torch.enable_grad(): - if t._is_view() and not safe_is_leaf(t._base): - base = base.clone() - r = base.as_strided(t.size(), t.stride(), t.storage_offset()) - torch._C._set_conj(r, t.is_conj()) - torch._C._set_neg(r, t.is_neg()) - return r - - def to_meta(t): - nonlocal hit, miss - # TODO: zero tensors? We appear to have eliminated them by - # excluding complex for now - if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: - if any([ - t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized, - t.is_nested, torch._is_functional_tensor(t), - # these are supported in meta conversion but the fallbacks - # don't work - t.is_neg(), t.is_conj(), - # conjugate fallback does not support meta tensors - t.dtype in (torch.complex128, torch.complex64), - ]): - # TODO: sparse should support meta - # NB technically to('meta') does work but our logging - # instrumentation will see the meta conversions and the - # tests all break so we just exclude this. In any case - # the to conversion isn't really right anyhow. - miss += 1 - return t - elif any([ - t.device.type in ("lazy", "meta"), t.is_complex(), - # We need a way to test if a tensor is batched but there - # is no official APi to do it - # torch._C._is_batched(t), - ]): - # TODO: this stuff should support storage - # (well, maybe not batched) - hit += 1 - return t.to("meta") - else: - hit += 1 - r = meta_tensor(t) - if type(t) is torch.nn.Parameter: - r = torch.nn.Parameter(r, requires_grad=r.requires_grad) - return r - elif torch.overrides.is_tensor_like(t): - # Blindly converting tensor subclasses to meta can cause - # unpredictable problems; e.g., FX tests will trace meta - # tensors into their trace / some subclasses don't correctly - # support meta. Trying to YOLO this is more trouble than it's - # worth. 
- miss += 1 - return t - else: - # non-Tensor types don't count as hit or miss - return t - - do_meta = ( - (self.run_excludes_anyway or func not in meta_exclude_set) and - not torch.jit.is_tracing() and - not isinstance(func, torch.ScriptMethod) + return run_meta_crossref( + self.test_case, test_expect, func, args, + kwargs, dtype=self.dtype, device_type=self.device_type ) - if do_meta: - try: - meta_args = tree_map(to_meta, args) - meta_kwargs = tree_map(to_meta, kwargs) - except Exception as e: - raise RuntimeError( - f"failed to convert args to meta; " - f"originally (*{args}, **{kwargs})") from e +aten = torch.ops.aten - rs = func(*args, **kwargs) +# these always fail +meta_dispatch_expected_failures = { + aten._adaptive_avg_pool2d.default: {bf16, f64, f32}, + aten._adaptive_avg_pool3d.default: {f16, f64, f32}, + aten._cdist_forward.default: {f64, f32}, + aten._conj_physical.default: {c32}, + aten._convolution.default: {c64, i64, f64, c128, bf16, f32}, + aten._ctc_loss.default: {f64, f32}, + aten._embedding_bag_forward_only.default: {f16, f64, f32}, + aten._fft_r2c.default: {i64, u8, b8, f32, i8, f64, i16, i32}, + aten._histogramdd_bin_edges.default: {f64, f32}, + aten._histogramdd_from_bin_cts.default: {f64, f32}, + aten._histogramdd_from_bin_tensors.default: {f64, f32}, + aten._local_scalar_dense.default: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten._pdist_forward.default: {f64, f32}, + aten._unique2.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, + aten.addbmm.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.addbmm.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.angle.default: {c32, i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.angle.out: {c32, i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.bernoulli.out: {bf16, f64, f32}, + aten.bincount.default: {i8, i64, i16, u8, i32}, + aten.bucketize.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, + aten.bucketize.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, + aten.col2im.default: {c64, f32, f64, c128}, + aten.complex.default: {c64, f64, c128, f16, f32}, + aten.complex.out: {f16}, + aten.conj_physical.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, c32, i32}, + aten.convolution.default: {c64, i64, f64, c128, bf16, f32}, + aten.count_nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.count_nonzero.dim_IntList: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.diag.default: {i64, u8, b8, f32, i8, f64, i16, i32, bf16}, + aten.diag.out: {bf16, i64, u8, b8, f32, i8, f64, i16, i32}, + aten.dot.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.dot.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.floor_divide.default: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, + aten.floor_divide.out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, + aten.frexp.Tensor: {bf16, f16, f64, f32}, + aten.grid_sampler_2d.default: {f64, f32}, + aten.grid_sampler_3d.default: {f64, f32}, + aten.histc.default: {bf16, f64, f32}, + aten.histc.out: {bf16, f64, f32}, + aten.histogram.bin_ct: {f64, f32}, + aten.histogram.bins_tensor: {f64, f32}, + aten.im2col.default: {bf16, f16, f64, f32}, + aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32}, + aten.kthvalue.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.linalg_matrix_exp.default: {bf16, f64, f32}, + aten.log_sigmoid_forward.output: {bf16, f64, f32}, + aten.logcumsumexp.default: {bf16, f64, f32}, + aten.logcumsumexp.out: {bf16, f64, f32}, + aten.logical_not.out: {i64, bf16, f16, 
u8, b8, f32, i8, f64, i16, i32}, + aten.logical_not_.default: {bf16, f16, f64, f32}, + aten.logical_xor.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.logit.out: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, + aten.masked_select.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.masked_select.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.max_pool3d_with_indices.default: {f64, f32}, + aten.max_unpool2d.default: {f64, f32}, + aten.max_unpool3d.default: {f64, f32}, + aten.median.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.median.dim: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.mode.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.multi_margin_loss.default: {f64, f32}, + aten.multilabel_margin_loss_forward.default: {f64, f32}, + aten.multinomial.default: {bf16, f64, f32}, + aten.multinomial.out: {bf16, f64, f32}, + aten.mvlgamma.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.mvlgamma.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.nan_to_num.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.nan_to_num.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.nanmedian.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.nanmedian.dim: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.nansum.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.nansum.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.native_group_norm.default: {bf16, f64, f32}, + aten.nll_loss2d_forward.default: {bf16, f64, f32}, + aten.nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.nonzero.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.normal.Tensor_Tensor: {bf16, f16, f64, f32}, + aten.normal.Tensor_Tensor_out: {bf16, f16, f64, f32}, + aten.normal.float_Tensor: {bf16, f16, f64, f32}, + aten.normal.float_Tensor_out: {bf16, f16, f64, f32}, + aten.polar.default: {f64, f32}, + aten.prelu.default: {bf16, f64, f32}, + aten.prod.default: {i64, u8, b8, f32, i8, f64, i16, i32}, + aten.reflection_pad2d.default: {f64, f32}, + aten.relu.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.repeat_interleave.Tensor: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.roll.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.rrelu_with_noise.default: {bf16, f64, f32}, + aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, + aten.searchsorted.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, + aten.square.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.std_mean.correction: {bf16, f16, f64, f32}, + aten.take.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.take.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.tensordot.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.to_sparse.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.to_sparse.sparse_dim: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.trace.default: {i8, i64, f64, i16, u8, i32, f32}, + aten.unique_consecutive.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, + aten.unique_dim.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, + aten.upsample_nearest3d.vec: {bf16, u8, f64, f32}, + aten.var_mean.correction: {bf16, f16, f64, f32}, + aten.vdot.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten.vdot.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, + aten._det_lu_based_helper.default: {f32, f64}, # aten::_det_lu_based_helper + aten._linalg_check_errors.default: {c128, 
c64, f32, f64}, # aten::_local_scalar_dense + aten.cholesky.default: {f32, f64}, # aten::cholesky + aten.cholesky.out: {f32, f64}, # aten::cholesky.out + aten.cholesky_inverse.default: {f32, f64}, # aten::cholesky_inverse + aten.cholesky_inverse.out: {f32, f64}, # aten::cholesky_inverse.out + aten.cholesky_solve.default: {f32, f64}, # aten::_cholesky_solve_helper + aten.cholesky_solve.out: {f32, f64}, # aten::_cholesky_solve_helper + aten.eig.default: {f32, f64}, # aten::_local_scalar_dense + aten.geqrf.default: {f32, f64}, # aten::geqrf + aten.inverse.out: {f32, f64}, # aten::_local_scalar_dense + aten.linalg_cholesky_ex.L: {f32, f64}, # aten::linalg_cholesky_ex.L + aten.linalg_cholesky_ex.default: {f32, f64}, # aten::linalg_cholesky_ex + aten.linalg_eig.default: {f32, f64}, # aten::linalg_eig + aten.linalg_eigh.default: {f32, f64}, + aten.linalg_eigvalsh.out: {f32, f64}, # aten::linalg_eigvalsh.out + aten.linalg_householder_product.default: {f32, f64}, # aten::linalg_householder_product + aten.linalg_householder_product.out: {f32, f64}, # aten::linalg_householder_product.out + aten.linalg_lstsq.default: {f32, f64}, # aten::linalg_lstsq.out + aten.linalg_qr.default: {f32, f64}, # aten::_linalg_qr_helper + aten.linalg_slogdet.default: {f32, f64}, # aten::linalg_slogdet + aten.linalg_solve.default: {f32, f64}, # aten::linalg_solve + aten.linalg_solve.out: {f32, f64}, # aten::linalg_solve.out + aten.linalg_solve_triangular.default: {f32, f64}, # aten::linalg_solve_triangular + aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out + aten.logdet.default: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero + aten.lu_solve.default: {f32, f64}, # aten::lu_solve + aten.lu_solve.out: {f32, f64}, # aten::lu_solve.out + aten.ormqr.default: {f32, f64}, # aten::ormqr + aten.ormqr.out: {f32, f64}, # aten::ormqr.out + aten.symeig.default: {f32, f64}, # aten::_symeig_helper +} - # TODO: also handle cases where func raise an exception +# these sometimes pass and sometimes fail +meta_dispatch_skips = { + aten.index_reduce.default: {bf16, f16, f64, f32}, + aten.index_reduce.out: {bf16, f16, f64, f32}, + aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.addr.default: {b8}, + aten.addr.out: {b8}, + aten.aminmax.default: {i64, u8, b8, f32, i8, f64, i16, i32}, + aten.copy_.default: {c32}, + aten.cummax.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, + aten.cummin.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, + aten.index_add.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.index_add.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.isnan.default: {f64, f32}, + aten.mul.Scalar: {i64, bf16, f16, f32, i8, f64, i16, i32}, + aten.native_batch_norm.default: {f64, f32}, + aten.native_layer_norm.default: {bf16, f64, f32}, + aten.slice.Tensor: {c32}, + aten.inverse.default: {f32, f64}, + aten.linalg_pinv.atol_rtol_tensor: {f32, f64}, + aten.linalg_pinv.atol_rtol_tensor_out: {f32, f64}, + aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, +} - # For now, only attempt if we managed to convert all tensor types - # (if any of them failed, we're in a mixed device situation and - # this isn't well supported) - if do_meta and hit > 0 and miss == 0: - try: - # suppress warnings - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - meta_rs = func(*meta_args, **meta_kwargs) - except Exception as e: - suppress = False - """ - # This code can be helpful for full crossref 
test to filter - # out "pedestrian" omissions - if isinstance(e, NotImplementedError): - m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0]) - if m and m.group(1) not in ("aten::_efficientzerotensor", "aten::view_as_real"): - suppress = True - """ - if not suppress: - raise RuntimeError(f"""\ -failed to run: {func}( - *{meta_args}, - **{meta_kwargs} - )""") from e - else: - def test_assert(cond, msg): - if not cond: - raise RuntimeError(f"""\ -meta disagrees with real impl: -{func}( - *{meta_args}, - **{meta_kwargs} -) = {meta_r} -{msg} -""") - flat_meta_rs, _ = tree_flatten(meta_rs) - flat_rs, _ = tree_flatten(rs) - self.test_case.assertEqual(len(flat_meta_rs), len(flat_rs)) - for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs): - if isinstance(r, torch.Tensor): - test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor") - test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}") - test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}") - # NOTE: this helper is used instead of a direct stride comparison - # because strides of tensors with no elements and dimensions of - # length 1 are not computed consistently - same_strides, idx = prims.utils.check_significant_strides(meta_r, r) - test_assert(same_strides, f"but real stride was {r.stride()}") - test_assert( - meta_r.storage_offset() == r.storage_offset(), - f"but real storage_offset was {r.storage_offset()}") - test_assert(meta_r.requires_grad == r.requires_grad, f"but real requires_grad was {r.requires_grad}") - test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}") - test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}") +meta_dispatch_device_expected_failures = defaultdict(dict) +meta_dispatch_device_skips = defaultdict(dict) - return rs +meta_dispatch_device_expected_failures['cuda'] = { + aten._adaptive_avg_pool2d.default: {f16}, # aten::_adaptive_avg_pool2d + aten._adaptive_avg_pool3d.default: {bf16}, # aten::_adaptive_avg_pool3d + aten._conj_physical.default: {f16}, # aten::conj_physical.out + aten._convolution.default: {f16}, + aten._embedding_bag_forward_only.default: {bf16}, # aten::_embedding_bag_forward_only + aten._fft_c2c.default: {c32, f16}, # aten::_fft_c2c + aten._fft_c2c.out: {c32, f16}, # aten::_fft_c2c.out + aten._fft_c2r.default: {c32, f16}, # aten::_fft_c2r + aten._fft_c2r.out: {c32, f16}, # aten::_fft_c2r.out + aten._fft_r2c.default: {f16}, # aten::_fft_r2c + aten._fft_r2c.out: {f16}, # aten::_fft_r2c.out + aten._linalg_check_errors.default: {c128, c64, f32, f64}, # aten::_local_scalar_dense + aten._unique2.default: {f16}, # aten::_unique2 + aten._use_cudnn_ctc_loss.default: {f32, f64}, # aten::_use_cudnn_ctc_loss + aten.addbmm.default: {f16}, # aten::addbmm + aten.addbmm.out: {f16}, # aten::addbmm.out + aten.bernoulli.out: {f16}, # aten::bernoulli.out + aten.convolution.default: {f16}, + aten.cudnn_grid_sampler.default: {f16, f32, f64}, # aten::cudnn_grid_sampler + aten.diag.default: {f16}, # aten::diag.out + aten.diag.out: {bf16, f16}, # aten::diag.out + aten.dot.default: {f16}, # aten::dot + aten.dot.out: {f16}, # aten::dot + aten.geqrf.default: {f32, f64}, # aten::geqrf + aten.grid_sampler_2d.default: {f16}, # aten::grid_sampler_2d + aten.grid_sampler_3d.default: {f16}, # aten::grid_sampler_3d + aten.histc.default: {i16, i32, i64, i8}, # aten::histc + aten.histc.out: {i16, i32, i64, i8}, # aten::histc.out + aten.index.Tensor: {c32}, # aten::index.Tensor + aten.inverse.out: {f32, f64}, # 
aten::_local_scalar_dense + aten.kthvalue.default: {f16}, # aten::kthvalue.values + aten.linalg_cholesky_ex.L: {f32, f64}, # aten::linalg_cholesky_ex.L + aten.linalg_cholesky_ex.default: {f32, f64}, # aten::linalg_cholesky_ex + aten.linalg_eigvalsh.out: {f32, f64}, # aten::linalg_eigvalsh.out + aten.linalg_householder_product.default: {f32, f64}, # aten::linalg_householder_product + aten.linalg_householder_product.out: {f32, f64}, # aten::linalg_householder_product.out + aten.linalg_matrix_exp.default: {f16}, # aten::linalg_matrix_exp + aten.linalg_qr.default: {f32, f64}, # aten::_linalg_qr_helper + aten.linalg_solve_triangular.default: {f32, f64}, # aten::linalg_solve_triangular + aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out + aten.log_sigmoid_forward.default: {bf16, f16, f64, f32}, + aten.log_sigmoid_forward.output: {f16}, # aten::log_sigmoid_forward.output + aten.logcumsumexp.default: {bf16, f16}, # aten::_logcumsumexp + aten.logcumsumexp.out: {bf16, f16}, # aten::_logcumsumexp.out + aten.logit.out: {f16}, + aten.max_pool3d_with_indices.default: {bf16, f16}, # aten::max_pool3d_with_indices + aten.max_unpool2d.default: {f16}, # aten::max_unpool2d + aten.max_unpool3d.default: {f16}, # aten::max_unpool3d + aten.median.default: {f16}, # aten::median + aten.median.dim: {f16}, # aten::median.dim_values + aten.multi_margin_loss.default: {bf16, f16}, # aten::multi_margin_loss + aten.multilabel_margin_loss_forward.default: {bf16, f16}, # aten::multilabel_margin_loss_forward + aten.multinomial.default: {f16}, # aten::multinomial + aten.multinomial.out: {f16}, # aten::multinomial.out + aten.mvlgamma.default: {f16}, # aten::_local_scalar_dense + aten.mvlgamma.out: {f16}, # aten::mvlgamma.out + aten.nanmedian.default: {f16}, # aten::nanmedian + aten.nanmedian.dim: {f16}, # aten::nanmedian.dim_values + aten.native_batch_norm.default: {bf16, f16}, # aten::var_mean.correction + aten.native_dropout.default: {bf16, f16, f32, f64}, + aten.native_group_norm.default: {bf16, f16}, # aten::var_mean.correction + aten.native_layer_norm.default: {f16}, # aten::var_mean.correction + aten.nll_loss2d_forward.default: {f16}, # aten::nll_loss2d_forward + aten.ormqr.default: {f32, f64}, # aten::ormqr + aten.ormqr.out: {f32, f64}, # aten::ormqr.out + aten.prelu.default: {f16}, # aten::prelu + aten.prod.default: {bf16, c32, f16}, # aten::prod + aten.reflection_pad2d.default: {f16}, # aten::reflection_pad2d + aten.relu.default: {f16}, # aten::relu + aten.rrelu_with_noise.default: {f16}, # aten::rrelu_with_noise + aten.tensordot.out: {f16}, # aten::tensordot.out + aten.trace.default: {b8, bf16, f16}, # aten::diag.out + aten.unique_consecutive.default: {f16}, # aten::unique_consecutive + aten.unique_dim.default: {f16}, # aten::unique_dim + aten.upsample_nearest3d.vec: {f16}, # aten::upsample_nearest3d.vec + aten.vdot.default: {f16}, # aten::vdot + aten.vdot.out: {f16}, # aten::vdot +} +meta_dispatch_device_skips['cuda'] = { + aten._conj.default: {c32, f16}, + aten._linalg_svd.default: {f32, f64}, + aten.cudnn_batch_norm.default: {f32, f64}, + aten.cummax.default: {f16}, + aten.cummin.default: {f16}, + aten.inverse.default: {f32, f64}, + aten.slice.Tensor: {f16}, + # ROCm stuff; technically this should be expected failure but it's + # not worth it; these should get unified anyway + aten.miopen_batch_norm.default: {f32}, +} + +class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode): + test_case: TestCase + device: torch.device + dtype: torch.dtype + + def 
__init__(self, test_case, *, device, dtype): + self.test_case = test_case + # save TLS + self.precision = test_case.precision + self.rel_tol = test_case.rel_tol + self.device_type = torch.device(device).type + self.dtype = dtype + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + self.test_case.precision = self.precision + self.test_case.rel_tol = self.rel_tol + + if self.dtype in meta_dispatch_skips.get(func, set()): + test_expect = TestExpect.SKIP + elif self.dtype in meta_dispatch_device_skips[self.device_type].get(func, set()): + test_expect = TestExpect.SKIP + elif self.dtype in meta_dispatch_expected_failures.get(func, set()): + test_expect = TestExpect.XFAILURE + elif self.dtype in meta_dispatch_device_expected_failures[self.device_type].get(func, set()): + test_expect = TestExpect.XFAILURE + else: + test_expect = TestExpect.SUCCESS + + return run_meta_crossref( + self.test_case, + test_expect, + func, + args, + kwargs, + dtype=self.dtype, + device_type=self.device_type, + ) + + +# NB: we're running these tests only on CUDA because there are some +# inconsistencies between CUDA and CPU, and running on CUDA makes it easier +# to ignore the CPU case when inconsistencies arise. Ideally we deal +# with the inconsistencies but this takes time. class TestMeta(TestCase): @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @onlyCUDA @@ -804,30 +1121,32 @@ class TestMeta(TestCase): def test_meta(self, device, dtype, op): # run the OpInfo sample inputs, cross-referencing them with the # meta implementation and check the results are the same. All - # the heavy lifting happens in MetaCrossRefMode + # the heavy lifting happens in MetaCrossRefFunctionMode func = op.get_op() + samples = op.sample_inputs(device, dtype, requires_grad=False) + for sample_input in samples: + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + with MetaCrossRefFunctionMode.push(self, dtype=dtype, device=device): + expected = func(*args, **kwargs) + if isinstance(expected, torch.Tensor) and op.supports_out: + func(*args, **kwargs, out=expected) - def do_test(run_excludes_anyway=False): - samples = op.sample_inputs(device, dtype, requires_grad=False) - for sample_input in samples: - args = [sample_input.input] + list(sample_input.args) - kwargs = sample_input.kwargs - with push_torch_function_mode(partial(MetaCrossRefMode, self, run_excludes_anyway=run_excludes_anyway)): - expected = func(*args, **kwargs) - if isinstance(expected, torch.Tensor) and op.supports_out: - func(*args, **kwargs, out=expected) - - if func in overload_exclude_set: - self.skipTest('permanently excluded') - elif func in meta_exclude_set and dtype not in (torch.complex128, torch.complex64): - try: - do_test(run_excludes_anyway=True) - except Exception: - pass - else: - self.fail('expected failure, but succeeded') - else: - do_test() + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @onlyCUDA + @skipIfCrossRef + @suppress_warnings + @ops(op_db) + def test_dispatch_meta(self, device, dtype, op): + func = op.get_op() + samples = op.sample_inputs(device, dtype, requires_grad=False) + for sample_input in samples: + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + with MetaCrossRefDispatchMode.push(self, dtype=dtype, device=device): + expected = func(*args, **kwargs) + if isinstance(expected, torch.Tensor) and op.supports_out: + func(*args, **kwargs, out=expected) instantiate_device_type_tests(TestMeta, globals()) diff --git 
a/test/test_overrides.py b/test/test_overrides.py index 378dd72b9622..d208a9201729 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1087,6 +1087,16 @@ class TestDisabledTorchFunction(TestCase): self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called") self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called") +class TestResolveName(TestCase): + def test_resolve_name(self): + for cs in get_overridable_functions().values(): + for c in cs: + self.assertEqual( + eval(torch.overrides.resolve_name(c)), + c, + msg=f"{c}, {torch.overrides.resolve_name(c)}" + ) + class TestTorchFunctionWarning(TestCase): def test_warn_on_invalid_torch_function(self): class Bad1(): diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index a19eaf01f48c..7aedd935c697 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -894,8 +894,8 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''') return func(*args, **kwargs) x = torch.randn(1) - with push_torch_dispatch_mode(partial(Logger, "A")): - with push_torch_dispatch_mode(partial(Logger, "B")): + with Logger.push("A"): + with Logger.push("B"): x + x self.assertEqual(logs, ["B", "A"]) diff --git a/test/test_view_ops.py b/test/test_view_ops.py index dddaa03c86b7..424a31e61d24 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -310,11 +310,7 @@ class TestViewOps(TestCase): res = torch.view_as_real(input) self.assertEqual(res[:, :, 0], input.real) self.assertEqual(res[:, :, 1], input.imag) - # TODO: Add torch.ComplexHalfStorage - if dtype != torch.complex32: - self.assertTrue(self.is_view_of(t, res)) - else: - self.assertRaises(RuntimeError, lambda: self.is_view_of(t, res)) + self.assertTrue(self.is_view_of(t, res)) fn() fn(contiguous_input=False) @@ -322,21 +318,13 @@ class TestViewOps(TestCase): # tensor with zero elements x = torch.tensor([], dtype=dtype, device=device) res = torch.view_as_real(x) - # TODO: Add torch.ComplexHalfStorage - if dtype != torch.complex32: - self.assertTrue(self.is_view_of(x, res)) - else: - self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res)) + self.assertTrue(self.is_view_of(x, res)) self.assertEqual(res.shape, torch.Size([0, 2])) # tensor with zero dim x = torch.tensor(2 + 3j, dtype=dtype, device=device) res = torch.view_as_real(x) - # TODO: Add torch.ComplexHalfStorage - if dtype != torch.complex32: - self.assertTrue(self.is_view_of(x, res)) - else: - self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res)) + self.assertTrue(self.is_view_of(x, res)) self.assertEqual(res.shape, torch.Size([2])) @onlyNativeDeviceTypes diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index c8dfb87c5974..1dd8faa71739 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -841,6 +841,8 @@ class Generator(object): # Defined in torch/csrc/utils/python_dispatch.cpp def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ... +def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ... +def _dispatch_has_kernel(name: str) -> _bool: ... 
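A minimal sketch (not part of the diff) of how the two new introspection bindings stubbed just above are meant to be queried from Python: it mirrors the gating logic this patch adds to torch/_decomp/__init__.py further down. The helper name `needs_meta_registration` is invented here for illustration, and whether a given operator already carries a Meta kernel depends on the build.

import torch

def needs_meta_registration(op: torch._ops.OpOverload) -> bool:
    # Rebuild the schema name the dispatcher expects, e.g. "aten::add.Tensor".
    name = op._schema.name
    if op._schema.overload_name:
        name += "." + op._schema.overload_name
    # Register a decomposition to the Meta key only if the operator has a real
    # dispatcher entry and does not already provide its own Meta kernel.
    return (
        torch._C._dispatch_has_kernel(name)
        and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta")
    )

print(needs_meta_registration(torch.ops.aten.add.Tensor))
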
# Defined in torch/csrc/utils/init.cpp class BenchmarkConfig(object): diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 4d5e0a060c1c..9d72a832538d 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -1,7 +1,7 @@ import torch import torch._ops import torch.library -from typing import Callable, Union, Dict, Sequence +from typing import Callable, Union, Dict, Sequence, List from torch.utils._pytree import tree_map from collections import defaultdict @@ -15,7 +15,7 @@ decomposition_table: Dict[torch._ops.OpOverload, Callable] = {} meta_lib = torch.library.Library("aten", "IMPL", "Meta") -def register_decomposition(aten_op, registry=None, *, register_meta: bool = False): +def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False): """ A decorator to register a function as a decomposition to the Python decomposition table. Use it like this:: @@ -32,9 +32,9 @@ def register_decomposition(aten_op, registry=None, *, register_meta: bool = Fals autograd) and not just backend tracing, where we then need to know if a decomposition can be used to simulate a transform. - If `register_meta` is True, we will also register this function to the - Meta key in the dispatcher, so that it will be used to compute meta - tensors. + By default, if the decomposition is for an operator that doesn't have + a Meta implementation, we will register it to the dispatcher. Use + `disable_meta` to disable this behavior. """ def decomposition_decorator(f): nonlocal registry @@ -53,7 +53,18 @@ def register_decomposition(aten_op, registry=None, *, register_meta: bool = Fals if op_overload in registry: raise RuntimeError(f"duplicate registrations for {op_overload}") registry[op_overload] = f - if register_meta: + # TODO: factor this logic into OpOverload or Library API + name = op_overload._schema.name + if op_overload._schema.overload_name: + name += "." + op_overload._schema.overload_name + if ( + not disable_meta + # TorchScript dumps a bunch of extra nonsense overloads + # which don't have corresponding dispatcher entries, we need + # to filter those out + and torch._C._dispatch_has_kernel(name) + and not torch._C._dispatch_has_kernel_for_dispatch_key(name, 'Meta') + ): meta_lib.impl(op_overload, f) # To handle allowing multiple aten_ops at once diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 87649b042118..5541506b72a5 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -392,7 +392,7 @@ def mse_loss_backward( return norm * (input - target) * grad_output -@register_decomposition(aten.huber_loss, register_meta=True) +@register_decomposition(aten.huber_loss) @pw_cast_for_opmath def huber_loss( self: Tensor, @@ -1125,7 +1125,7 @@ def std_decomposition( # Questionable decompositions # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced. 
# Note that this decomposition causes issues with in-place ops -@register_decomposition(aten.detach) +@register_decomposition(aten.detach, disable_meta=True) def detach_decomposition(x): return x diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index c012ada17dd4..894baf3605bc 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -247,7 +247,7 @@ def _make_elementwise_unary_reference( *, type_promotion_kind, aten_op=infer_aten_op, - register_meta=False, + disable_meta=False, extra_meta=None, ) -> Callable: @out_wrapper @@ -269,7 +269,7 @@ def _make_elementwise_unary_reference( if aten_op is infer_aten_op: aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0]) if aten_op is not None: - register_decomposition(aten_op, register_meta=register_meta)(_ref) + register_decomposition(aten_op, disable_meta=disable_meta)(_ref) return _ref @@ -373,7 +373,6 @@ isnan = _make_elementwise_unary_reference( _isnan, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.isnan, # prim/aten name mismatch - register_meta=True, ) lgamma = _make_elementwise_unary_reference( @@ -456,7 +455,7 @@ def _make_elementwise_binary_reference( has_out=True, supports_lhs_python_scalar=True, supports_rhs_python_scalar=True, - register_meta=False, + disable_meta=False, ) -> Callable: @elementwise_type_promotion_wrapper( type_promoting_args=("a", "b"), @@ -491,7 +490,7 @@ def _make_elementwise_binary_reference( if aten_op is infer_aten_op: aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0]) if aten_op is not None: - register_decomposition(aten_op, register_meta=register_meta)(_ref) + register_decomposition(aten_op, disable_meta=disable_meta)(_ref) return _ref @@ -717,7 +716,6 @@ logical_and = _make_elementwise_binary_reference( _logical_and, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_and, - register_meta=True, ) @@ -733,7 +731,6 @@ logical_or = _make_elementwise_binary_reference( _logical_or, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_or, - register_meta=True, ) # TODO: add docstring @@ -836,7 +833,7 @@ true_divide = _make_elementwise_binary_reference( # https://pytorch.org/docs/stable/generated/torch.where.html # TODO: implement alternate where -@register_decomposition(torch.ops.aten.where, register_meta=True) +@register_decomposition(torch.ops.aten.where) @out_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("a", "b"), @@ -1090,7 +1087,7 @@ def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorL return prims.collapse(a, start_dim, end_dim + 1) -@register_decomposition(torch.ops.aten.flip, register_meta=True) +@register_decomposition(torch.ops.aten.flip) def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: if not isinstance(dims, tuple) and not isinstance(dims, list): raise ValueError("dims has to be a sequence of ints") diff --git a/torch/_tensor.py b/torch/_tensor.py index d359a1e93ac4..37383c17af28 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -202,9 +202,6 @@ class Tensor(torch._C._TensorBase): if has_torch_function_unary(self): return handle_torch_function(Tensor.storage, (self,), self) - if self.dtype not in torch.storage._dtype_to_storage_type_map(): - raise RuntimeError(f'unsupported Storage type: {self.dtype}') - return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype) def _reduce_ex_internal(self, proto): diff --git 
a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 0b186a91a147..34a867f0f555 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -202,6 +202,17 @@ void initDispatchBindings(PyObject* module) { c10::Dispatcher::singleton().checkInvariants(); }); + m.def("_dispatch_has_kernel", [](const char* name) -> bool { + auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); + return static_cast(op); + }); + + m.def("_dispatch_has_kernel_for_dispatch_key", [](const char* name, const char* dispatch) -> bool { + auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); + TORCH_CHECK(op, "operator ", name, " does not exist"); + return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch)); + }); + m.def("_dispatch_find_dangling_impls", []() -> std::vector { auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls(); diff --git a/torch/overrides.py b/torch/overrides.py index aab5b2f70caf..383061ecfc99 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -26,7 +26,7 @@ import collections import functools import types import warnings -from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator +from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator, Tuple import contextlib import torch @@ -42,6 +42,7 @@ __all__ = [ "get_testing_overrides", "handle_torch_function", "has_torch_function", + "resolve_name", "is_tensor_like", "is_tensor_method_or_property", "wrap_torch_function", @@ -1556,36 +1557,32 @@ has_torch_function_variadic = _add_docstr( ) @functools.lru_cache(None) -def get_overridable_functions() -> Dict[Any, List[Callable]]: - """List functions that are overridable via __torch_function__ - - Returns - ------- - Dict[Any, List[Callable]] - A dictionary that maps namespaces that contain overridable functions - to functions in that namespace that can be overridden. 
- """ +def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]: overridable_funcs = collections.defaultdict(list) + index = {} tested_namespaces = [ - (torch, torch.__all__ + dir(torch._C._VariableFunctions)), - (torch.functional, torch.functional.__all__), - (torch.nn.functional, dir(torch.nn.functional)), - (torch.nn.init, dir(torch.nn.init)), - (torch.Tensor, dir(torch.Tensor)), - (torch.linalg, dir(torch.linalg)), - (torch.fft, dir(torch.fft)), - (torch.special, dir(torch.special)), + ("torch", torch, torch.__all__ + dir(torch._C._VariableFunctions)), + ("torch.functional", torch.functional, torch.functional.__all__), + ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)), + ("torch.nn.init", torch.nn.init, dir(torch.nn.init)), + ("torch.Tensor", torch.Tensor, dir(torch.Tensor)), + ("torch.linalg", torch.linalg, dir(torch.linalg)), + ("torch.fft", torch.fft, dir(torch.fft)), + ("torch.special", torch.special, dir(torch.special)), ] - for namespace, ns_funcs in tested_namespaces: + for namespace_str, namespace, ns_funcs in tested_namespaces: for func_name in ns_funcs: + ignore = False # ignore private functions or functions that are deleted in torch.__init__ if namespace is not torch.Tensor: - if func_name.startswith('_'): + if func_name.startswith('__'): continue + elif func_name.startswith('_'): + ignore = True elif func_name.endswith('_'): - continue + ignore = True elif not func_name[0].islower(): - continue + ignore = True elif func_name == 'unique_dim': continue else: @@ -1605,6 +1602,10 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]: continue if not callable(func) and hasattr(func, "__get__"): + index[func.__get__] = f"{namespace_str}.{func_name}.__get__" + index[func.__set__] = f"{namespace_str}.{func_name}.__set__" + if ignore: + continue if func.__get__ in get_ignored_functions(): msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " "but still has an explicit override") @@ -1617,6 +1618,11 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]: if not callable(func): continue + index[func] = f"{namespace_str}.{func_name}" + + if ignore: + continue + # cannot be overriden by __torch_function__ if func in get_ignored_functions(): msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " @@ -1624,7 +1630,37 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]: assert func not in get_testing_overrides(), msg.format(namespace, func.__name__) continue overridable_funcs[namespace].append(func) - return overridable_funcs + return overridable_funcs, index + +def get_overridable_functions() -> Dict[Any, List[Callable]]: + """List functions that are overridable via __torch_function__ + + Returns + ------- + Dict[Any, List[Callable]] + A dictionary that maps namespaces that contain overridable functions + to functions in that namespace that can be overridden. + """ + return _get_overridable_functions()[0] + +def resolve_name(f): + """Get a human readable string name for a function passed to + __torch_function__ + + Arguments + --------- + callable : Callable + Function to resolve the name of. + + Returns + ------- + str + Name of the function; if eval'ed it should give back the input + function. 
+ """ + if isinstance(f, torch._ops.OpOverload): + return str(f) + return _get_overridable_functions()[1].get(f) @functools.lru_cache(None) def _get_tensor_methods() -> Set[Callable]: @@ -1782,6 +1818,10 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta): def __torch_function__(self, func, types, args=(), kwargs=None): raise NotImplementedError() + @classmethod + def push(cls, *args, **kwargs): + return push_torch_function_mode(functools.partial(cls, *args, **kwargs)) + class BaseTorchFunctionMode(TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 797772a39e0c..b1a3bea97bca 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10026,7 +10026,6 @@ op_db: List[OpInfo] = [ # Reference: https://github.com/pytorch/pytorch/issues/50747 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', dtypes=(torch.bool,)), ), sample_inputs_func=sample_inputs_addr, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), @@ -11441,8 +11440,6 @@ op_db: List[OpInfo] = [ skips=( # Skip since real and imag don't have out variants. DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', - dtypes=(torch.complex32,)), )), OpInfo('gradient', dtypes=floating_and_complex_types_and(torch.int8, torch.int16, @@ -12036,6 +12033,7 @@ op_db: List[OpInfo] = [ dtypes=floating_and_complex_types(), supports_forward_ad=True, supports_fwgrad_bwgrad=True, + skips=(skipCPUIfNoLapack,), sample_inputs_func=sample_inputs_lu_unpack), OpInfo('lu', op=torch.lu, @@ -12069,7 +12067,7 @@ op_db: List[OpInfo] = [ # See https://github.com/pytorch/pytorch/issues/66357 check_batched_forward_grad=False, sample_inputs_func=sample_inputs_lu_solve, - decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], skips=( DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps', dtypes=[torch.float32]), @@ -12595,7 +12593,6 @@ op_db: List[OpInfo] = [ skips=( # AssertionError: Resizing an out= argument with no elements threw a resize warning! DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cpu'), )), OpInfo('as_strided', op=lambda x, size, stride, storage_offset=0: @@ -13334,9 +13331,6 @@ op_db: List[OpInfo] = [ skips=( # Pre-existing condition; Needs to be fixed DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cpu'), - # RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16' - DecorateInfo(unittest.skip("Works on some configs"), 'TestMeta', - 'test_meta', dtypes=(torch.bfloat16,)), DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)), DecorateInfo(unittest.skip("Works on some conifgs"), 'TestCudaFuserOpInfo', @@ -14382,8 +14376,6 @@ op_db: List[OpInfo] = [ skips=( # Skip since real and imag don't have out variants. 
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', - dtypes=(torch.complex32,)), )), OpInfo('roll', ref=np.roll, @@ -14896,7 +14888,7 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), - decorators=[skipCUDAIfNoMagma], + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], skips=( # AssertionError: Scalars are not equal! DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), @@ -15119,12 +15111,7 @@ op_db: List[OpInfo] = [ ref=np.isfinite, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), supports_out=False, - supports_autograd=False, - skips=( - # NotImplementedError: - # Could not run 'aten::view_as_real' with arguments from the 'Meta' backend. - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)), - )), + supports_autograd=False), UnaryUfuncInfo('isinf', ref=np.isinf, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), @@ -15133,8 +15120,6 @@ op_db: List[OpInfo] = [ supports_sparse_csr=True, supports_autograd=False, skips=( - # Could not run 'aten::view_as_real' with arguments from the 'Meta' backend. - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)), # "nonzero_count_cpu" not implemented for 'ComplexHalf' # "nonzero_cuda" not implemented for 'ComplexHalf' DecorateInfo(unittest.expectedFailure, "TestSparseCSR", @@ -15160,12 +15145,7 @@ op_db: List[OpInfo] = [ ref=np.isreal, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), supports_out=False, - supports_autograd=False, - skips=( - # NotImplementedError: - # Could not run 'aten::view_as_real' with arguments from the 'Meta' backend. 
- DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)), - )), + supports_autograd=False), UnaryUfuncInfo('isnan', ref=np.isnan, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), @@ -15198,6 +15178,7 @@ op_db: List[OpInfo] = [ dtypes=floating_and_complex_types(), sample_inputs_func=sample_inputs_linalg_solve_triangular, supports_fwgrad_bwgrad=True, + skips=(skipCPUIfNoLapack,), # linalg.solve_triangular cannot be batched over because of a call to out.copy_(result); supports_forward_ad=True), OpInfo('linalg.matrix_rank', @@ -15242,6 +15223,7 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_linalg_pinv, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], skips=( # errors with "leaked XXXX bytes CUDA memory on device 0" DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),) @@ -15344,9 +15326,6 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - # stride mismatch - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', - dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps', dtypes=[torch.float32]), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', @@ -15371,9 +15350,6 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - # stride mismatch - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', - dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps', dtypes=[torch.float32]), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', @@ -15418,8 +15394,6 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - # stride mismatch - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM), )), OpInfo('pca_lowrank', op=lambda *args, **kwargs: wrapper_set_seed( @@ -15444,8 +15418,6 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - # stride mismatch - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM), )), BinaryUfuncInfo('polar', dtypes=floating_types(), @@ -17423,7 +17395,7 @@ op_db: List[OpInfo] = [ sample_inputs_func=sample_inputs_tensorsolve, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - 
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver], + decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagma], ), OpInfo( "nn.functional.mse_loss", diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index a39291e6d5c8..c3ed92133056 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -147,6 +147,10 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta): def __torch_dispatch__(self, func, types, args=(), kwargs=None): raise NotImplementedError() + @classmethod + def push(cls, *args, **kwargs): + return push_torch_dispatch_mode(functools.partial(cls, *args, **kwargs)) + class BaseTorchDispatchMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None):
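
The `push` classmethods added to TorchFunctionMode (torch/overrides.py above) and TorchDispatchMode here are thin wrappers over push_torch_function_mode / push_torch_dispatch_mode that defer construction through functools.partial, which is what lets the updated test write `with Logger.push("A")`. A minimal usage sketch, assuming this patch is applied; `LoggingMode` and its `log` argument are hypothetical names used only for illustration.

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    # Hypothetical mode: records the name of every aten op it intercepts.
    def __init__(self, log):
        self.log = log

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.log.append(str(func))
        return func(*args, **kwargs)

log = []
# push() partially applies the constructor and installs the mode for the
# duration of the with block, exactly as if the mode had been built by hand
# and handed to push_torch_dispatch_mode.
with LoggingMode.push(log):
    torch.ones(2) + torch.ones(2)
print(log)  # e.g. entries for aten.ones.default and aten.add.Tensor; exact names may vary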