Revert "Revert "Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#76836)""

This reverts commit c35bd8d423ca53408c3aa39c2280167f3a22cea0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77719

Approved by: https://github.com/Chillee, https://github.com/malfet
Author: Edward Z. Yang
Date: 2022-05-18 08:53:20 -07:00
Committed by: PyTorch MergeBot
Parent: befa4e371e
Commit: 4941e72e40
16 changed files with 1212 additions and 859 deletions


@@ -32,13 +32,13 @@ void max_pool1d_impl(
Tensor& output,
const Tensor& input,
const PoolingParams1D& p) {
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] {
const Tensor in = input.contiguous();
scalar_t* const OP = output.data_ptr<scalar_t>();
const scalar_t* const IP = in.data_ptr<scalar_t>();
// Value used for padding
constexpr scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
? -std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::lowest();
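
This hunk widens the CPU max_pool1d kernel's dtype dispatch to include BFloat16 and drops constexpr from FILL (presumably because the BFloat16 arithmetic involved is not constexpr). A minimal sketch of the user-visible effect, assuming a build with this change applied:

import torch
import torch.nn.functional as F

# Before this change the CPU kernel raised "not implemented for 'BFloat16'";
# with the widened AT_DISPATCH_FLOATING_TYPES_AND(BFloat16, ...) it should run.
x = torch.randn(1, 1, 8, dtype=torch.bfloat16)
out = F.max_pool1d(x, kernel_size=2)
print(out.shape, out.dtype)  # torch.Size([1, 1, 4]) torch.bfloat16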


@@ -330,12 +330,12 @@
- func: view_as_real(Tensor(a) self) -> Tensor(a)
variants: function
dispatch:
CPU, CUDA, MPS: view_as_real
CPU, CUDA, MPS, Meta: view_as_real
- func: view_as_complex(Tensor(a) self) -> Tensor(a)
variants: function
dispatch:
CPU, CUDA: view_as_complex
CPU, CUDA, Meta: view_as_complex
- func: sgn(Tensor self) -> Tensor
variants: function, method
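
Registering view_as_real and view_as_complex for the Meta dispatch key means these views can be computed on meta tensors (shape and dtype only, no storage), which the meta-tensor test changes below rely on. A small sketch of what this enables, assuming this build:

import torch

z = torch.empty(3, 4, dtype=torch.complex64, device="meta")
r = torch.view_as_real(z)           # now has a Meta kernel
print(r.shape, r.dtype, r.device)   # torch.Size([3, 4, 2]) torch.float32 meta
c = torch.view_as_complex(r)
print(c.shape, c.dtype)             # torch.Size([3, 4]) torch.complex64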


@@ -14,6 +14,8 @@ Functions
.. autofunction:: get_overridable_functions
.. autofunction:: resolve_name
.. autofunction:: get_testing_overrides
.. autofunction:: handle_torch_function

File diff suppressed because it is too large.


@@ -1087,6 +1087,16 @@ class TestDisabledTorchFunction(TestCase):
self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
class TestResolveName(TestCase):
def test_resolve_name(self):
for cs in get_overridable_functions().values():
for c in cs:
self.assertEqual(
eval(torch.overrides.resolve_name(c)),
c,
msg=f"{c}, {torch.overrides.resolve_name(c)}"
)
class TestTorchFunctionWarning(TestCase):
def test_warn_on_invalid_torch_function(self):
class Bad1():


@@ -894,8 +894,8 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''')
return func(*args, **kwargs)
x = torch.randn(1)
with push_torch_dispatch_mode(partial(Logger, "A")):
with push_torch_dispatch_mode(partial(Logger, "B")):
with Logger.push("A"):
with Logger.push("B"):
x + x
self.assertEqual(logs, ["B", "A"])


@@ -310,11 +310,7 @@ class TestViewOps(TestCase):
res = torch.view_as_real(input)
self.assertEqual(res[:, :, 0], input.real)
self.assertEqual(res[:, :, 1], input.imag)
# TODO: Add torch.ComplexHalfStorage
if dtype != torch.complex32:
self.assertTrue(self.is_view_of(t, res))
else:
self.assertRaises(RuntimeError, lambda: self.is_view_of(t, res))
self.assertTrue(self.is_view_of(t, res))
fn()
fn(contiguous_input=False)
@@ -322,21 +318,13 @@ class TestViewOps(TestCase):
# tensor with zero elements
x = torch.tensor([], dtype=dtype, device=device)
res = torch.view_as_real(x)
# TODO: Add torch.ComplexHalfStorage
if dtype != torch.complex32:
self.assertTrue(self.is_view_of(x, res))
else:
self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
self.assertTrue(self.is_view_of(x, res))
self.assertEqual(res.shape, torch.Size([0, 2]))
# tensor with zero dim
x = torch.tensor(2 + 3j, dtype=dtype, device=device)
res = torch.view_as_real(x)
# TODO: Add torch.ComplexHalfStorage
if dtype != torch.complex32:
self.assertTrue(self.is_view_of(x, res))
else:
self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
self.assertTrue(self.is_view_of(x, res))
self.assertEqual(res.shape, torch.Size([2]))
@onlyNativeDeviceTypes


@@ -841,6 +841,8 @@ class Generator(object):
# Defined in torch/csrc/utils/python_dispatch.cpp
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):


@@ -1,7 +1,7 @@
import torch
import torch._ops
import torch.library
from typing import Callable, Union, Dict, Sequence
from typing import Callable, Union, Dict, Sequence, List
from torch.utils._pytree import tree_map
from collections import defaultdict
@@ -15,7 +15,7 @@ decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
def register_decomposition(aten_op, registry=None, *, register_meta: bool = False):
def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False):
"""
A decorator to register a function as a decomposition to the Python
decomposition table. Use it like this::
@@ -32,9 +32,9 @@ def register_decomposition(aten_op, registry=None, *, register_meta: bool = Fals
autograd) and not just backend tracing, where we then need to know if a
decomposition can be used to simulate a transform.
If `register_meta` is True, we will also register this function to the
Meta key in the dispatcher, so that it will be used to compute meta
tensors.
By default, if the decomposition is for an operator that doesn't have
a Meta implementation, we will register it to the dispatcher. Use
`disable_meta` to disable this behavior.
"""
def decomposition_decorator(f):
nonlocal registry
@@ -53,7 +53,18 @@ def register_decomposition(aten_op, registry=None, *, register_meta: bool = Fals
if op_overload in registry:
raise RuntimeError(f"duplicate registrations for {op_overload}")
registry[op_overload] = f
if register_meta:
# TODO: factor this logic into OpOverload or Library API
name = op_overload._schema.name
if op_overload._schema.overload_name:
name += "." + op_overload._schema.overload_name
if (
not disable_meta
# TorchScript dumps a bunch of extra nonsense overloads
# which don't have corresponding dispatcher entries, we need
# to filter those out
and torch._C._dispatch_has_kernel(name)
and not torch._C._dispatch_has_kernel_for_dispatch_key(name, 'Meta')
):
meta_lib.impl(op_overload, f)
# To handle allowing multiple aten_ops at once
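
With this change, register_decomposition registers the function to the Python table and, by default, also to the Meta dispatch key, but only when the operator exists in the dispatcher and has no Meta kernel yet; disable_meta=True opts out. A hedged usage sketch (the import path, the my_registry dict, and the silu example are illustrative, not taken from the patch):

import torch
from torch._decomp import register_decomposition  # import path assumed

aten = torch.ops.aten
my_registry = {}  # illustrative; omitting it uses the global decomposition table

# Recorded in my_registry and, if aten.silu has no Meta kernel in this build,
# also installed under the Meta key so meta tensors can use the decomposition.
@register_decomposition(aten.silu, my_registry)
def silu(x):
    return x * torch.sigmoid(x)

# Opt out of Meta registration, as the patch itself does for aten.detach.
@register_decomposition(aten.detach, my_registry, disable_meta=True)
def detach(x):
    return x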


@@ -392,7 +392,7 @@ def mse_loss_backward(
return norm * (input - target) * grad_output
@register_decomposition(aten.huber_loss, register_meta=True)
@register_decomposition(aten.huber_loss)
@pw_cast_for_opmath
def huber_loss(
self: Tensor,
@@ -1125,7 +1125,7 @@ def std_decomposition(
# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
@register_decomposition(aten.detach)
@register_decomposition(aten.detach, disable_meta=True)
def detach_decomposition(x):
return x


@@ -247,7 +247,7 @@ def _make_elementwise_unary_reference(
*,
type_promotion_kind,
aten_op=infer_aten_op,
register_meta=False,
disable_meta=False,
extra_meta=None,
) -> Callable:
@out_wrapper
@@ -269,7 +269,7 @@ def _make_elementwise_unary_reference(
if aten_op is infer_aten_op:
aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
if aten_op is not None:
register_decomposition(aten_op, register_meta=register_meta)(_ref)
register_decomposition(aten_op, disable_meta=disable_meta)(_ref)
return _ref
@@ -373,7 +373,6 @@ isnan = _make_elementwise_unary_reference(
_isnan,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
aten_op=torch.ops.aten.isnan, # prim/aten name mismatch
register_meta=True,
)
lgamma = _make_elementwise_unary_reference(
@@ -456,7 +455,7 @@ def _make_elementwise_binary_reference(
has_out=True,
supports_lhs_python_scalar=True,
supports_rhs_python_scalar=True,
register_meta=False,
disable_meta=False,
) -> Callable:
@elementwise_type_promotion_wrapper(
type_promoting_args=("a", "b"),
@@ -491,7 +490,7 @@ def _make_elementwise_binary_reference(
if aten_op is infer_aten_op:
aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
if aten_op is not None:
register_decomposition(aten_op, register_meta=register_meta)(_ref)
register_decomposition(aten_op, disable_meta=disable_meta)(_ref)
return _ref
@@ -717,7 +716,6 @@ logical_and = _make_elementwise_binary_reference(
_logical_and,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
aten_op=torch.ops.aten.logical_and,
register_meta=True,
)
@@ -733,7 +731,6 @@ logical_or = _make_elementwise_binary_reference(
_logical_or,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
aten_op=torch.ops.aten.logical_or,
register_meta=True,
)
# TODO: add docstring
@@ -836,7 +833,7 @@ true_divide = _make_elementwise_binary_reference(
# https://pytorch.org/docs/stable/generated/torch.where.html
# TODO: implement alternate where
@register_decomposition(torch.ops.aten.where, register_meta=True)
@register_decomposition(torch.ops.aten.where)
@out_wrapper
@elementwise_type_promotion_wrapper(
type_promoting_args=("a", "b"),
@@ -1090,7 +1087,7 @@ def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorL
return prims.collapse(a, start_dim, end_dim + 1)
@register_decomposition(torch.ops.aten.flip, register_meta=True)
@register_decomposition(torch.ops.aten.flip)
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
if not isinstance(dims, tuple) and not isinstance(dims, list):
raise ValueError("dims has to be a sequence of ints")


@@ -202,9 +202,6 @@ class Tensor(torch._C._TensorBase):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.storage, (self,), self)
if self.dtype not in torch.storage._dtype_to_storage_type_map():
raise RuntimeError(f'unsupported Storage type: {self.dtype}')
return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
def _reduce_ex_internal(self, proto):
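
With the dtype guard removed, Tensor.storage() no longer rejects dtypes that lack a legacy typed-storage class; it simply wraps the untyped storage in a _TypedStorage. A small sketch of the assumed effect for complex32:

import torch

t = torch.zeros(4, dtype=torch.complex32)
s = t.storage()  # previously raised RuntimeError('unsupported Storage type: torch.complex32')
print(type(s).__name__, s.dtype)  # _TypedStorage torch.complex32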


@@ -202,6 +202,17 @@ void initDispatchBindings(PyObject* module) {
c10::Dispatcher::singleton().checkInvariants();
});
m.def("_dispatch_has_kernel", [](const char* name) -> bool {
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
return static_cast<bool>(op);
});
m.def("_dispatch_has_kernel_for_dispatch_key", [](const char* name, const char* dispatch) -> bool {
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
TORCH_CHECK(op, "operator ", name, " does not exist");
return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
});
m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
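
These two bindings let Python probe the dispatcher before registering anything, which is exactly what the new register_decomposition logic uses. A minimal sketch of how they might be called (the example operator and printed results are illustrative):

import torch

# Is the operator known to the dispatcher at all?
print(torch._C._dispatch_has_kernel("aten::add.Tensor"))  # True

# Does it already have a kernel for a given dispatch key?
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::add.Tensor", "CPU"))   # True
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::add.Tensor", "Meta"))  # depends on the build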


@@ -26,7 +26,7 @@ import collections
import functools
import types
import warnings
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator, Tuple
import contextlib
import torch
@@ -42,6 +42,7 @@ __all__ = [
"get_testing_overrides",
"handle_torch_function",
"has_torch_function",
"resolve_name",
"is_tensor_like",
"is_tensor_method_or_property",
"wrap_torch_function",
@@ -1556,36 +1557,32 @@ has_torch_function_variadic = _add_docstr(
)
@functools.lru_cache(None)
def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""List functions that are overridable via __torch_function__
Returns
-------
Dict[Any, List[Callable]]
A dictionary that maps namespaces that contain overridable functions
to functions in that namespace that can be overridden.
"""
def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
overridable_funcs = collections.defaultdict(list)
index = {}
tested_namespaces = [
(torch, torch.__all__ + dir(torch._C._VariableFunctions)),
(torch.functional, torch.functional.__all__),
(torch.nn.functional, dir(torch.nn.functional)),
(torch.nn.init, dir(torch.nn.init)),
(torch.Tensor, dir(torch.Tensor)),
(torch.linalg, dir(torch.linalg)),
(torch.fft, dir(torch.fft)),
(torch.special, dir(torch.special)),
("torch", torch, torch.__all__ + dir(torch._C._VariableFunctions)),
("torch.functional", torch.functional, torch.functional.__all__),
("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
("torch.linalg", torch.linalg, dir(torch.linalg)),
("torch.fft", torch.fft, dir(torch.fft)),
("torch.special", torch.special, dir(torch.special)),
]
for namespace, ns_funcs in tested_namespaces:
for namespace_str, namespace, ns_funcs in tested_namespaces:
for func_name in ns_funcs:
ignore = False
# ignore private functions or functions that are deleted in torch.__init__
if namespace is not torch.Tensor:
if func_name.startswith('_'):
if func_name.startswith('__'):
continue
elif func_name.startswith('_'):
ignore = True
elif func_name.endswith('_'):
continue
ignore = True
elif not func_name[0].islower():
continue
ignore = True
elif func_name == 'unique_dim':
continue
else:
@@ -1605,6 +1602,10 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
continue
if not callable(func) and hasattr(func, "__get__"):
index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
if ignore:
continue
if func.__get__ in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override")
@@ -1617,6 +1618,11 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
if not callable(func):
continue
index[func] = f"{namespace_str}.{func_name}"
if ignore:
continue
# cannot be overridden by __torch_function__
if func in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
@@ -1624,7 +1630,37 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
continue
overridable_funcs[namespace].append(func)
return overridable_funcs
return overridable_funcs, index
def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""List functions that are overridable via __torch_function__
Returns
-------
Dict[Any, List[Callable]]
A dictionary that maps namespaces that contain overridable functions
to functions in that namespace that can be overridden.
"""
return _get_overridable_functions()[0]
def resolve_name(f):
"""Get a human readable string name for a function passed to
__torch_function__
Arguments
---------
callable : Callable
Function to resolve the name of.
Returns
-------
str
Name of the function; if eval'ed it should give back the input
function.
"""
if isinstance(f, torch._ops.OpOverload):
return str(f)
return _get_overridable_functions()[1].get(f)
@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
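
resolve_name maps an overridable function to the dotted path it is reachable under, using the index built alongside get_overridable_functions, so eval on the result recovers the original callable (the property the new TestResolveName test checks). A short usage sketch:

import torch
from torch.overrides import get_overridable_functions, resolve_name

print(resolve_name(torch.add))         # 'torch.add'
print(resolve_name(torch.Tensor.add))  # 'torch.Tensor.add'

# Round trip, as in TestResolveName.test_resolve_name:
for funcs in get_overridable_functions().values():
    for f in funcs:
        assert eval(resolve_name(f)) == f
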
@@ -1782,6 +1818,10 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
def __torch_function__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
@classmethod
def push(cls, *args, **kwargs):
return push_torch_function_mode(functools.partial(cls, *args, **kwargs))
class BaseTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
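
The new classmethod is shorthand for push_torch_function_mode(functools.partial(cls, *args, **kwargs)), so a mode can be constructed and entered in one expression. A hedged sketch with an illustrative mode (NameLogger is not part of the patch):

import torch
from torch.overrides import TorchFunctionMode, resolve_name

class NameLogger(TorchFunctionMode):
    # Illustrative mode: record the resolved name of each torch-level call.
    seen = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        NameLogger.seen.append(resolve_name(func))
        return func(*args, **(kwargs or {}))

a, b = torch.ones(2), torch.ones(2)
with NameLogger.push():   # same as push_torch_function_mode(partial(NameLogger))
    torch.add(a, b)
print(NameLogger.seen)    # ['torch.add']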


@@ -10026,7 +10026,6 @@ op_db: List[OpInfo] = [
# Reference: https://github.com/pytorch/pytorch/issues/50747
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', dtypes=(torch.bool,)),
),
sample_inputs_func=sample_inputs_addr,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
@@ -11441,8 +11440,6 @@ op_db: List[OpInfo] = [
skips=(
# Skip since real and imag don't have out variants.
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta',
dtypes=(torch.complex32,)),
)),
OpInfo('gradient',
dtypes=floating_and_complex_types_and(torch.int8, torch.int16,
@@ -12036,6 +12033,7 @@ op_db: List[OpInfo] = [
dtypes=floating_and_complex_types(),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(skipCPUIfNoLapack,),
sample_inputs_func=sample_inputs_lu_unpack),
OpInfo('lu',
op=torch.lu,
@@ -12069,7 +12067,7 @@ op_db: List[OpInfo] = [
# See https://github.com/pytorch/pytorch/issues/66357
check_batched_forward_grad=False,
sample_inputs_func=sample_inputs_lu_solve,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='mps', dtypes=[torch.float32]),
@@ -12595,7 +12593,6 @@ op_db: List[OpInfo] = [
skips=(
# AssertionError: Resizing an out= argument with no elements threw a resize warning!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cpu'),
)),
OpInfo('as_strided',
op=lambda x, size, stride, storage_offset=0:
@@ -13334,9 +13331,6 @@ op_db: List[OpInfo] = [
skips=(
# Pre-existing condition; Needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cpu'),
# RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16'
DecorateInfo(unittest.skip("Works on some configs"), 'TestMeta',
'test_meta', dtypes=(torch.bfloat16,)),
DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
'test_nnc_correctness', dtypes=(torch.bfloat16,)),
DecorateInfo(unittest.skip("Works on some conifgs"), 'TestCudaFuserOpInfo',
@@ -14382,8 +14376,6 @@ op_db: List[OpInfo] = [
skips=(
# Skip since real and imag don't have out variants.
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta',
dtypes=(torch.complex32,)),
)),
OpInfo('roll',
ref=np.roll,
@@ -14896,7 +14888,7 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
decorators=[skipCUDAIfNoMagma],
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
skips=(
# AssertionError: Scalars are not equal!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
@@ -15119,12 +15111,7 @@ op_db: List[OpInfo] = [
ref=np.isfinite,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
supports_out=False,
supports_autograd=False,
skips=(
# NotImplementedError:
# Could not run 'aten::view_as_real' with arguments from the 'Meta' backend.
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)),
)),
supports_autograd=False),
UnaryUfuncInfo('isinf',
ref=np.isinf,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
@@ -15133,8 +15120,6 @@ op_db: List[OpInfo] = [
supports_sparse_csr=True,
supports_autograd=False,
skips=(
# Could not run 'aten::view_as_real' with arguments from the 'Meta' backend.
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)),
# "nonzero_count_cpu" not implemented for 'ComplexHalf'
# "nonzero_cuda" not implemented for 'ComplexHalf'
DecorateInfo(unittest.expectedFailure, "TestSparseCSR",
@@ -15160,12 +15145,7 @@ op_db: List[OpInfo] = [
ref=np.isreal,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
supports_out=False,
supports_autograd=False,
skips=(
# NotImplementedError:
# Could not run 'aten::view_as_real' with arguments from the 'Meta' backend.
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)),
)),
supports_autograd=False),
UnaryUfuncInfo('isnan',
ref=np.isnan,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
@@ -15198,6 +15178,7 @@ op_db: List[OpInfo] = [
dtypes=floating_and_complex_types(),
sample_inputs_func=sample_inputs_linalg_solve_triangular,
supports_fwgrad_bwgrad=True,
skips=(skipCPUIfNoLapack,),
# linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
supports_forward_ad=True),
OpInfo('linalg.matrix_rank',
@@ -15242,6 +15223,7 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_linalg_pinv,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
skips=(
# errors with "leaked XXXX bytes CUDA memory on device 0"
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),)
@@ -15344,9 +15326,6 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# stride mismatch
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda',
dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='mps', dtypes=[torch.float32]),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
@@ -15371,9 +15350,6 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# stride mismatch
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda',
dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='mps', dtypes=[torch.float32]),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
@@ -15418,8 +15394,6 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# stride mismatch
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM),
)),
OpInfo('pca_lowrank',
op=lambda *args, **kwargs: wrapper_set_seed(
@@ -15444,8 +15418,6 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# stride mismatch
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM),
)),
BinaryUfuncInfo('polar',
dtypes=floating_types(),
@@ -17423,7 +17395,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_tensorsolve,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagma],
),
OpInfo(
"nn.functional.mse_loss",


@@ -147,6 +147,10 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
@classmethod
def push(cls, *args, **kwargs):
return push_torch_dispatch_mode(functools.partial(cls, *args, **kwargs))
class BaseTorchDispatchMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
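
TorchDispatchMode gets the same push helper, and the test_python_dispatch hunk above is written against it. A rough reconstruction of the Logger mode that test uses (its definition lies outside the hunk, so the details are assumptions):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

logs = []

class Logger(TorchDispatchMode):
    # Assumed shape of the test's Logger: tag every dispatched aten op with a name.
    def __init__(self, name):
        self.name = name

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        logs.append(self.name)
        return func(*args, **(kwargs or {}))

x = torch.randn(1)
with Logger.push("A"):    # equivalent to push_torch_dispatch_mode(partial(Logger, "A"))
    with Logger.push("B"):
        x + x
print(logs)               # ['B', 'A']: the most recently pushed mode handles the call first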