Revert "Add dispatch mode testing for meta tensors and other stuff"

This reverts commit c1cdb1216b97970d903a6d6e9e7d0e2b4ffaef46.

Reverted https://github.com/pytorch/pytorch/pull/77477 on behalf of https://github.com/malfet
PyTorch MergeBot
2022-05-18 02:56:48 +00:00
parent c35bd8d423
commit 48581d74ad
16 changed files with 846 additions and 1214 deletions

View File

@ -32,13 +32,13 @@ void max_pool1d_impl(
Tensor& output,
const Tensor& input,
const PoolingParams1D& p) {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] {
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] {
const Tensor in = input.contiguous();
scalar_t* const OP = output.data_ptr<scalar_t>();
const scalar_t* const IP = in.data_ptr<scalar_t>();
// Value used for padding
scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
constexpr scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
? -std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::lowest();
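With the dispatch macro reverted to AT_DISPATCH_FLOATING_TYPES, this CPU kernel is only instantiated for float and double, which is why the BFloat16 skips later in this diff quote the error below. A rough repro sketch, assuming the bfloat16 input actually reaches max_pool1d_impl and is not handled by some other path:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 8, dtype=torch.bfloat16)  # CPU tensor, shape (N, C, L)
    # On a build without the BFloat16 instantiation this is expected to raise:
    # RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16'
    F.max_pool1d(x, kernel_size=2)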

View File

@ -330,12 +330,12 @@
- func: view_as_real(Tensor(a) self) -> Tensor(a)
variants: function
dispatch:
CPU, CUDA, MPS, Meta: view_as_real
CPU, CUDA, MPS: view_as_real
- func: view_as_complex(Tensor(a) self) -> Tensor(a)
variants: function
dispatch:
CPU, CUDA, Meta: view_as_complex
CPU, CUDA: view_as_complex
- func: sgn(Tensor self) -> Tensor
variants: function, method
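Dropping Meta from these dispatch entries means there is no longer a kernel for meta tensors here. As a hedged illustration of what the Meta entry enabled (shape/dtype propagation without real data; the exact error text below is an assumption):

    import torch

    z = torch.empty(3, dtype=torch.complex64, device="meta")
    # With a Meta kernel registered this returns a meta tensor of shape (3, 2);
    # without one, dispatch is expected to fail with a NotImplementedError along
    # the lines of "Could not run 'aten::view_as_real' ... 'Meta' backend".
    r = torch.view_as_real(z)
    print(r.shape, r.device)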

View File

@ -14,8 +14,6 @@ Functions
.. autofunction:: get_overridable_functions
.. autofunction:: resolve_name
.. autofunction:: get_testing_overrides
.. autofunction:: handle_torch_function

File diff suppressed because it is too large

View File

@ -1087,16 +1087,6 @@ class TestDisabledTorchFunction(TestCase):
self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
class TestResolveName(TestCase):
def test_resolve_name(self):
for cs in get_overridable_functions().values():
for c in cs:
self.assertEqual(
eval(torch.overrides.resolve_name(c)),
c,
msg=f"{c}, {torch.overrides.resolve_name(c)}"
)
class TestTorchFunctionWarning(TestCase):
def test_warn_on_invalid_torch_function(self):
class Bad1():

View File

@ -894,8 +894,8 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''')
return func(*args, **kwargs)
x = torch.randn(1)
with Logger.push("A"):
with Logger.push("B"):
with push_torch_dispatch_mode(partial(Logger, "A")):
with push_torch_dispatch_mode(partial(Logger, "B")):
x + x
self.assertEqual(logs, ["B", "A"])
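The test goes back to the functional push_torch_dispatch_mode helper instead of the Mode.push classmethod removed later in this diff. Most of the Logger body sits outside the hunk, so the sketch below is a reconstruction, not the actual test code; the import location and the class body are assumptions, grounded only in the trailing return func(*args, **kwargs) line and the ["B", "A"] assertion shown above:

    import torch
    from functools import partial
    from torch.utils._python_dispatch import TorchDispatchMode, push_torch_dispatch_mode

    logs = []

    class Logger(TorchDispatchMode):
        def __init__(self, name):
            self.name = name

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            logs.append(self.name)
            return func(*args, **(kwargs or {}))

    x = torch.randn(1)
    with push_torch_dispatch_mode(partial(Logger, "A")):
        with push_torch_dispatch_mode(partial(Logger, "B")):
            x + x
    # The innermost mode sees the call first and redispatch reaches the outer one,
    # hence logs == ["B", "A"].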

View File

@ -310,7 +310,11 @@ class TestViewOps(TestCase):
res = torch.view_as_real(input)
self.assertEqual(res[:, :, 0], input.real)
self.assertEqual(res[:, :, 1], input.imag)
self.assertTrue(self.is_view_of(t, res))
# TODO: Add torch.ComplexHalfStorage
if dtype != torch.complex32:
self.assertTrue(self.is_view_of(t, res))
else:
self.assertRaises(RuntimeError, lambda: self.is_view_of(t, res))
fn()
fn(contiguous_input=False)
@ -318,13 +322,21 @@ class TestViewOps(TestCase):
# tensor with zero elements
x = torch.tensor([], dtype=dtype, device=device)
res = torch.view_as_real(x)
self.assertTrue(self.is_view_of(x, res))
# TODO: Add torch.ComplexHalfStorage
if dtype != torch.complex32:
self.assertTrue(self.is_view_of(x, res))
else:
self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
self.assertEqual(res.shape, torch.Size([0, 2]))
# tensor with zero dim
x = torch.tensor(2 + 3j, dtype=dtype, device=device)
res = torch.view_as_real(x)
self.assertTrue(self.is_view_of(x, res))
# TODO: Add torch.ComplexHalfStorage
if dtype != torch.complex32:
self.assertTrue(self.is_view_of(x, res))
else:
self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
self.assertEqual(res.shape, torch.Size([2]))
@onlyNativeDeviceTypes
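The complex32 special-casing above exists because there was no torch.ComplexHalfStorage yet (see the TODOs), so the is_view_of helper, which presumably touches .storage(), hits the RuntimeError path instead. The shape contract of view_as_real that the test checks is unchanged and easy to see directly:

    import torch

    x = torch.tensor(2 + 3j)                     # zero-dim complex tensor
    print(torch.view_as_real(x).shape)           # torch.Size([2])  -> [real, imag]

    e = torch.tensor([], dtype=torch.complex64)  # zero-element complex tensor
    print(torch.view_as_real(e).shape)           # torch.Size([0, 2])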

View File

@ -841,8 +841,6 @@ class Generator(object):
# Defined in torch/csrc/utils/python_dispatch.cpp
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):

View File

@ -1,7 +1,7 @@
import torch
import torch._ops
import torch.library
from typing import Callable, Union, Dict, Sequence, List
from typing import Callable, Union, Dict, Sequence
from torch.utils._pytree import tree_map
from collections import defaultdict
@ -15,7 +15,7 @@ decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False):
def register_decomposition(aten_op, registry=None, *, register_meta: bool = False):
"""
A decorator to register a function as a decomposition to the Python
decomposition table. Use it like this::
@ -32,9 +32,9 @@ def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False
autograd) and not just backend tracing, where we then need to know if a
decomposition can be used to simulate a transform.
By default, if the decomposition is for an operator that doesn't have
a Meta implementation, we will register it to the dispatcher. Use
`disable_meta` to disable this behavior.
If `register_meta` is True, we will also register this function to the
Meta key in the dispatcher, so that it will be used to compute meta
tensors.
"""
def decomposition_decorator(f):
nonlocal registry
@ -53,18 +53,7 @@ def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False
if op_overload in registry:
raise RuntimeError(f"duplicate registrations for {op_overload}")
registry[op_overload] = f
# TODO: factor this logic into OpOverload or Library API
name = op_overload._schema.name
if op_overload._schema.overload_name:
name += "." + op_overload._schema.overload_name
if (
not disable_meta
# TorchScript dumps a bunch of extra nonsense overloads
# which don't have corresponding dispatcher entries, we need
# to filter those out
and torch._C._dispatch_has_kernel(name)
and not torch._C._dispatch_has_kernel_for_dispatch_key(name, 'Meta')
):
if register_meta:
meta_lib.impl(op_overload, f)
# To handle allowing multiple aten_ops at once
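The docstring's "Use it like this::" example sits outside this hunk, so here is a hedged usage sketch against the post-revert signature shown above (register_meta instead of disable_meta). The import path and the choice of aten.sigmoid_backward are assumptions, and a private registry is used to avoid colliding with PyTorch's own registrations:

    import torch
    from torch._decomp import register_decomposition  # module path assumed from context

    aten = torch.ops.aten
    my_registry = {}  # registry= lets us register without touching the global table

    @register_decomposition(aten.sigmoid_backward, registry=my_registry, register_meta=False)
    def sigmoid_backward_decomp(grad_output, output):
        # d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x)), with output = sigmoid(x)
        return grad_output * output * (1 - output)

    # With register_meta=True the function would additionally be installed as a Meta
    # kernel via meta_lib.impl(...); the reverted PR instead did this automatically
    # unless disable_meta=True, using the _dispatch_has_kernel* checks shown above.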

View File

@ -391,7 +391,7 @@ def mse_loss_backward(
return norm * (input - target) * grad_output
@register_decomposition(aten.huber_loss)
@register_decomposition(aten.huber_loss, register_meta=True)
@pw_cast_for_opmath
def huber_loss(
self: Tensor,
@ -1124,7 +1124,7 @@ def std_decomposition(
# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
@register_decomposition(aten.detach, disable_meta=True)
@register_decomposition(aten.detach)
def detach_decomposition(x):
return x

View File

@ -236,7 +236,7 @@ infer_aten_op = object()
# TODO: add type promotion support
def _make_elementwise_unary_reference(
prim: Callable, *, type_promotion_kind, aten_op=infer_aten_op, disable_meta=False
prim: Callable, *, type_promotion_kind, aten_op=infer_aten_op, register_meta=False
) -> Callable:
@out_wrapper
@elementwise_type_promotion_wrapper(
@ -248,7 +248,7 @@ def _make_elementwise_unary_reference(
if aten_op is infer_aten_op:
aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
if aten_op is not None:
register_decomposition(aten_op, disable_meta=disable_meta)(_ref)
register_decomposition(aten_op, register_meta=register_meta)(_ref)
return _ref
@ -330,6 +330,7 @@ isnan = _make_elementwise_unary_reference(
_isnan,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
aten_op=torch.ops.aten.isnan, # prim/aten name mismatch
register_meta=True,
)
lgamma = _make_elementwise_unary_reference(
@ -392,7 +393,7 @@ def _make_elementwise_binary_reference(
type_promotion_kind,
aten_op=infer_aten_op,
has_out=True,
disable_meta=False,
register_meta=False,
) -> Callable:
@elementwise_type_promotion_wrapper(
type_promoting_args=("a", "b"), type_promotion_kind=type_promotion_kind
@ -410,7 +411,7 @@ def _make_elementwise_binary_reference(
if aten_op is infer_aten_op:
aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
if aten_op is not None:
register_decomposition(aten_op, disable_meta=disable_meta)(_ref)
register_decomposition(aten_op, register_meta=register_meta)(_ref)
return _ref
@ -467,8 +468,7 @@ bitwise_left_shift = _make_elementwise_binary_reference(
# TODO: add docstring
bitwise_or = _make_elementwise_binary_reference(
prims.bitwise_or,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
prims.bitwise_or, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
# TODO: add docstring
@ -480,8 +480,7 @@ bitwise_right_shift = _make_elementwise_binary_reference(
# TODO: add docstring
bitwise_xor = _make_elementwise_binary_reference(
prims.bitwise_xor,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
prims.bitwise_xor, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
# TODO: add docstring
@ -604,6 +603,7 @@ logical_and = _make_elementwise_binary_reference(
_logical_and,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
aten_op=torch.ops.aten.logical_and,
register_meta=True,
)
@ -619,6 +619,7 @@ logical_or = _make_elementwise_binary_reference(
_logical_or,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
aten_op=torch.ops.aten.logical_or,
register_meta=True,
)
# TODO: add docstring
@ -710,7 +711,7 @@ true_divide = _make_elementwise_binary_reference(
# https://pytorch.org/docs/stable/generated/torch.where.html
# TODO: implement alternate where
@register_decomposition(torch.ops.aten.where)
@register_decomposition(torch.ops.aten.where, register_meta=True)
@out_wrapper
@elementwise_type_promotion_wrapper(
type_promoting_args=("a", "b"),
@ -948,7 +949,7 @@ def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorL
return prims.collapse(a, start_dim, end_dim + 1)
@register_decomposition(torch.ops.aten.flip)
@register_decomposition(torch.ops.aten.flip, register_meta=True)
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
if not isinstance(dims, tuple) and not isinstance(dims, list):
raise ValueError("dims has to be a sequence of ints")

View File

@ -202,6 +202,9 @@ class Tensor(torch._C._TensorBase):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.storage, (self,), self)
if self.dtype not in torch.storage._dtype_to_storage_type_map():
raise RuntimeError(f'unsupported Storage type: {self.dtype}')
return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
def _reduce_ex_internal(self, proto):
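This guard is presumably what the complex32 branches in the view_as_real tests earlier in the diff run into: dtypes without a typed-storage class (no torch.ComplexHalfStorage at the time) cannot produce a storage. A small hedged illustration; whether the complex32 tensor can even be constructed this way depends on the build, so the failing call is left commented out:

    import torch

    t = torch.tensor([1 + 1j], dtype=torch.complex64)
    s = t.storage()     # a torch._TypedStorage with dtype=torch.complex64, per the return above
    print(s.dtype)

    # h = torch.tensor([1 + 1j], dtype=torch.complex32)
    # h.storage()  # expected: RuntimeError: unsupported Storage type: torch.complex32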

View File

@ -202,17 +202,6 @@ void initDispatchBindings(PyObject* module) {
c10::Dispatcher::singleton().checkInvariants();
});
m.def("_dispatch_has_kernel", [](const char* name) -> bool {
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
return static_cast<bool>(op);
});
m.def("_dispatch_has_kernel_for_dispatch_key", [](const char* name, const char* dispatch) -> bool {
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
TORCH_CHECK(op, "operator ", name, " does not exist");
return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
});
m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
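These bindings, whose .pyi stubs are also dropped earlier in the diff, exposed dispatcher introspection to Python; the reverted register_decomposition logic used them to check whether an operator already had a Meta kernel before auto-registering one. A hedged usage sketch against the pre-revert API, with the operator-name format taken from that logic (schema name plus optional overload name):

    import torch

    name = "aten::add.Tensor"
    # Both bindings exist only while the reverted PR's code is in the tree.
    if torch._C._dispatch_has_kernel(name):
        has_meta = torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta")
        print(f"{name} already has a Meta kernel: {has_meta}")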

View File

@ -26,7 +26,7 @@ import collections
import functools
import types
import warnings
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator, Tuple
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator
import contextlib
import torch
@ -42,7 +42,6 @@ __all__ = [
"get_testing_overrides",
"handle_torch_function",
"has_torch_function",
"resolve_name",
"is_tensor_like",
"is_tensor_method_or_property",
"wrap_torch_function",
@ -1557,32 +1556,36 @@ has_torch_function_variadic = _add_docstr(
)
@functools.lru_cache(None)
def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""List functions that are overridable via __torch_function__
Returns
-------
Dict[Any, List[Callable]]
A dictionary that maps namespaces that contain overridable functions
to functions in that namespace that can be overridden.
"""
overridable_funcs = collections.defaultdict(list)
index = {}
tested_namespaces = [
("torch", torch, torch.__all__ + dir(torch._C._VariableFunctions)),
("torch.functional", torch.functional, torch.functional.__all__),
("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
("torch.linalg", torch.linalg, dir(torch.linalg)),
("torch.fft", torch.fft, dir(torch.fft)),
("torch.special", torch.special, dir(torch.special)),
(torch, torch.__all__ + dir(torch._C._VariableFunctions)),
(torch.functional, torch.functional.__all__),
(torch.nn.functional, dir(torch.nn.functional)),
(torch.nn.init, dir(torch.nn.init)),
(torch.Tensor, dir(torch.Tensor)),
(torch.linalg, dir(torch.linalg)),
(torch.fft, dir(torch.fft)),
(torch.special, dir(torch.special)),
]
for namespace_str, namespace, ns_funcs in tested_namespaces:
for namespace, ns_funcs in tested_namespaces:
for func_name in ns_funcs:
ignore = False
# ignore private functions or functions that are deleted in torch.__init__
if namespace is not torch.Tensor:
if func_name.startswith('__'):
if func_name.startswith('_'):
continue
elif func_name.startswith('_'):
ignore = True
elif func_name.endswith('_'):
ignore = True
continue
elif not func_name[0].islower():
ignore = True
continue
elif func_name == 'unique_dim':
continue
else:
@ -1602,10 +1605,6 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
continue
if not callable(func) and hasattr(func, "__get__"):
index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
if ignore:
continue
if func.__get__ in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override")
@ -1618,11 +1617,6 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
if not callable(func):
continue
index[func] = f"{namespace_str}.{func_name}"
if ignore:
continue
# cannot be overridden by __torch_function__
if func in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
@ -1630,37 +1624,7 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
continue
overridable_funcs[namespace].append(func)
return overridable_funcs, index
def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""List functions that are overridable via __torch_function__
Returns
-------
Dict[Any, List[Callable]]
A dictionary that maps namespaces that contain overridable functions
to functions in that namespace that can be overridden.
"""
return _get_overridable_functions()[0]
def resolve_name(f):
"""Get a human readable string name for a function passed to
__torch_function__
Arguments
---------
callable : Callable
Function to resolve the name of.
Returns
-------
str
Name of the function; if eval'ed it should give back the input
function.
"""
if isinstance(f, torch._ops.OpOverload):
return str(f)
return _get_overridable_functions()[1].get(f)
return overridable_funcs
@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
@ -1818,10 +1782,6 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
def __torch_function__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
@classmethod
def push(cls, *args, **kwargs):
return push_torch_function_mode(functools.partial(cls, *args, **kwargs))
class BaseTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
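The hunk above removes resolve_name (and the cached qualified-name index behind it) and folds get_overridable_functions back into a single helper without the name index. While resolve_name existed, its contract was the round trip that the removed TestResolveName test asserted: the returned string eval's back to the original callable. A hedged sketch of that pre-revert usage:

    import torch
    from torch.overrides import get_overridable_functions, resolve_name  # resolve_name: pre-revert only

    funcs = get_overridable_functions()   # Dict[namespace, List[Callable]]
    some_func = funcs[torch][0]           # an arbitrary overridable function in the torch namespace
    name = resolve_name(some_func)        # e.g. a string like "torch.add" (exact value varies)
    assert eval(name) == some_func        # the round trip TestResolveName checked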

View File

@ -9939,6 +9939,7 @@ op_db: List[OpInfo] = [
# Reference: https://github.com/pytorch/pytorch/issues/50747
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', dtypes=(torch.bool,)),
),
sample_inputs_func=sample_inputs_addr,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
@ -11352,6 +11353,8 @@ op_db: List[OpInfo] = [
skips=(
# Skip since real and imag don't have out variants.
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta',
dtypes=(torch.complex32,)),
)),
OpInfo('gradient',
dtypes=floating_and_complex_types_and(torch.int8, torch.int16,
@ -11976,7 +11979,6 @@ op_db: List[OpInfo] = [
dtypes=floating_and_complex_types(),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(skipCPUIfNoLapack,),
sample_inputs_func=sample_inputs_lu_unpack),
OpInfo('lu',
op=torch.lu,
@ -12010,7 +12012,7 @@ op_db: List[OpInfo] = [
# See https://github.com/pytorch/pytorch/issues/66357
check_batched_forward_grad=False,
sample_inputs_func=sample_inputs_lu_solve,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='mps', dtypes=[torch.float32]),
@ -12526,6 +12528,7 @@ op_db: List[OpInfo] = [
skips=(
# AssertionError: Resizing an out= argument with no elements threw a resize warning!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cpu'),
)),
OpInfo('as_strided',
op=lambda x, size, stride, storage_offset=0:
@ -13253,6 +13256,9 @@ op_db: List[OpInfo] = [
skips=(
# Pre-existing condition; Needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cpu'),
# RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16'
DecorateInfo(unittest.skip("Works on some configs"), 'TestMeta',
'test_meta', dtypes=(torch.bfloat16,)),
DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
'test_nnc_correctness', dtypes=(torch.bfloat16,)),
DecorateInfo(unittest.skip("Works on some conifgs"), 'TestCudaFuserOpInfo',
@ -14308,6 +14314,8 @@ op_db: List[OpInfo] = [
skips=(
# Skip since real and imag don't have out variants.
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta',
dtypes=(torch.complex32,)),
)),
OpInfo('roll',
ref=np.roll,
@ -14823,7 +14831,7 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
decorators=[skipCUDAIfNoMagma],
skips=(
# AssertionError: Scalars are not equal!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
@ -15125,7 +15133,6 @@ op_db: List[OpInfo] = [
dtypes=floating_and_complex_types(),
sample_inputs_func=sample_inputs_linalg_solve_triangular,
supports_fwgrad_bwgrad=True,
skips=(skipCPUIfNoLapack,),
# linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
supports_forward_ad=True),
OpInfo('linalg.matrix_rank',
@ -15170,7 +15177,6 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_linalg_pinv,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
skips=(
# errors with "leaked XXXX bytes CUDA memory on device 0"
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),)
@ -15273,6 +15279,9 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# stride mismatch
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda',
dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='mps', dtypes=[torch.float32]),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
@ -15297,6 +15306,9 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# stride mismatch
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda',
dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='mps', dtypes=[torch.float32]),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
@ -15341,6 +15353,8 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# stride mismatch
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM),
)),
OpInfo('pca_lowrank',
op=lambda *args, **kwargs: wrapper_set_seed(
@ -15365,6 +15379,8 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# stride mismatch
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM),
)),
BinaryUfuncInfo('polar',
dtypes=floating_types(),
@ -17317,7 +17333,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_tensorsolve,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagma],
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
),
OpInfo(
"nn.functional.mse_loss",

View File

@ -147,10 +147,6 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
@classmethod
def push(cls, *args, **kwargs):
return push_torch_dispatch_mode(functools.partial(cls, *args, **kwargs))
class BaseTorchDispatchMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):