mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Revert "Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#76836)""
This reverts commit c35bd8d423ca53408c3aa39c2280167f3a22cea0. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77719 Approved by: https://github.com/Chillee, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
befa4e371e
commit
4941e72e40
@ -32,13 +32,13 @@ void max_pool1d_impl(
|
||||
Tensor& output,
|
||||
const Tensor& input,
|
||||
const PoolingParams1D& p) {
|
||||
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] {
|
||||
const Tensor in = input.contiguous();
|
||||
scalar_t* const OP = output.data_ptr<scalar_t>();
|
||||
const scalar_t* const IP = in.data_ptr<scalar_t>();
|
||||
|
||||
// Value used for padding
|
||||
constexpr scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
|
||||
scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
|
||||
? -std::numeric_limits<scalar_t>::infinity()
|
||||
: std::numeric_limits<scalar_t>::lowest();
|
||||
|
||||
|
@ -330,12 +330,12 @@
|
||||
- func: view_as_real(Tensor(a) self) -> Tensor(a)
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: view_as_real
|
||||
CPU, CUDA, MPS, Meta: view_as_real
|
||||
|
||||
- func: view_as_complex(Tensor(a) self) -> Tensor(a)
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: view_as_complex
|
||||
CPU, CUDA, Meta: view_as_complex
|
||||
|
||||
- func: sgn(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
|
@ -14,6 +14,8 @@ Functions
|
||||
|
||||
.. autofunction:: get_overridable_functions
|
||||
|
||||
.. autofunction:: resolve_name
|
||||
|
||||
.. autofunction:: get_testing_overrides
|
||||
|
||||
.. autofunction:: handle_torch_function
|
||||
|
1837
test/test_meta.py
1837
test/test_meta.py
File diff suppressed because it is too large
Load Diff
@ -1087,6 +1087,16 @@ class TestDisabledTorchFunction(TestCase):
|
||||
self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
|
||||
self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
|
||||
|
||||
class TestResolveName(TestCase):
|
||||
def test_resolve_name(self):
|
||||
for cs in get_overridable_functions().values():
|
||||
for c in cs:
|
||||
self.assertEqual(
|
||||
eval(torch.overrides.resolve_name(c)),
|
||||
c,
|
||||
msg=f"{c}, {torch.overrides.resolve_name(c)}"
|
||||
)
|
||||
|
||||
class TestTorchFunctionWarning(TestCase):
|
||||
def test_warn_on_invalid_torch_function(self):
|
||||
class Bad1():
|
||||
|
@ -894,8 +894,8 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''')
|
||||
return func(*args, **kwargs)
|
||||
|
||||
x = torch.randn(1)
|
||||
with push_torch_dispatch_mode(partial(Logger, "A")):
|
||||
with push_torch_dispatch_mode(partial(Logger, "B")):
|
||||
with Logger.push("A"):
|
||||
with Logger.push("B"):
|
||||
x + x
|
||||
self.assertEqual(logs, ["B", "A"])
|
||||
|
||||
|
@ -310,11 +310,7 @@ class TestViewOps(TestCase):
|
||||
res = torch.view_as_real(input)
|
||||
self.assertEqual(res[:, :, 0], input.real)
|
||||
self.assertEqual(res[:, :, 1], input.imag)
|
||||
# TODO: Add torch.ComplexHalfStorage
|
||||
if dtype != torch.complex32:
|
||||
self.assertTrue(self.is_view_of(t, res))
|
||||
else:
|
||||
self.assertRaises(RuntimeError, lambda: self.is_view_of(t, res))
|
||||
self.assertTrue(self.is_view_of(t, res))
|
||||
|
||||
fn()
|
||||
fn(contiguous_input=False)
|
||||
@ -322,21 +318,13 @@ class TestViewOps(TestCase):
|
||||
# tensor with zero elements
|
||||
x = torch.tensor([], dtype=dtype, device=device)
|
||||
res = torch.view_as_real(x)
|
||||
# TODO: Add torch.ComplexHalfStorage
|
||||
if dtype != torch.complex32:
|
||||
self.assertTrue(self.is_view_of(x, res))
|
||||
else:
|
||||
self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
|
||||
self.assertTrue(self.is_view_of(x, res))
|
||||
self.assertEqual(res.shape, torch.Size([0, 2]))
|
||||
|
||||
# tensor with zero dim
|
||||
x = torch.tensor(2 + 3j, dtype=dtype, device=device)
|
||||
res = torch.view_as_real(x)
|
||||
# TODO: Add torch.ComplexHalfStorage
|
||||
if dtype != torch.complex32:
|
||||
self.assertTrue(self.is_view_of(x, res))
|
||||
else:
|
||||
self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
|
||||
self.assertTrue(self.is_view_of(x, res))
|
||||
self.assertEqual(res.shape, torch.Size([2]))
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
|
@ -841,6 +841,8 @@ class Generator(object):
|
||||
|
||||
# Defined in torch/csrc/utils/python_dispatch.cpp
|
||||
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
|
||||
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
|
||||
def _dispatch_has_kernel(name: str) -> _bool: ...
|
||||
|
||||
# Defined in torch/csrc/utils/init.cpp
|
||||
class BenchmarkConfig(object):
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch._ops
|
||||
import torch.library
|
||||
from typing import Callable, Union, Dict, Sequence
|
||||
from typing import Callable, Union, Dict, Sequence, List
|
||||
from torch.utils._pytree import tree_map
|
||||
from collections import defaultdict
|
||||
|
||||
@ -15,7 +15,7 @@ decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
|
||||
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
|
||||
|
||||
|
||||
def register_decomposition(aten_op, registry=None, *, register_meta: bool = False):
|
||||
def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False):
|
||||
"""
|
||||
A decorator to register a function as a decomposition to the Python
|
||||
decomposition table. Use it like this::
|
||||
@ -32,9 +32,9 @@ def register_decomposition(aten_op, registry=None, *, register_meta: bool = Fals
|
||||
autograd) and not just backend tracing, where we then need to know if a
|
||||
decomposition can be used to simulate a transform.
|
||||
|
||||
If `register_meta` is True, we will also register this function to the
|
||||
Meta key in the dispatcher, so that it will be used to compute meta
|
||||
tensors.
|
||||
By default, if the decomposition is for an operator that doesn't have
|
||||
a Meta implementation, we will register it to the dispatcher. Use
|
||||
`disable_meta` to disable this behavior.
|
||||
"""
|
||||
def decomposition_decorator(f):
|
||||
nonlocal registry
|
||||
@ -53,7 +53,18 @@ def register_decomposition(aten_op, registry=None, *, register_meta: bool = Fals
|
||||
if op_overload in registry:
|
||||
raise RuntimeError(f"duplicate registrations for {op_overload}")
|
||||
registry[op_overload] = f
|
||||
if register_meta:
|
||||
# TODO: factor this logic into OpOverload or Library API
|
||||
name = op_overload._schema.name
|
||||
if op_overload._schema.overload_name:
|
||||
name += "." + op_overload._schema.overload_name
|
||||
if (
|
||||
not disable_meta
|
||||
# TorchScript dumps a bunch of extra nonsense overloads
|
||||
# which don't have corresponding dispatcher entries, we need
|
||||
# to filter those out
|
||||
and torch._C._dispatch_has_kernel(name)
|
||||
and not torch._C._dispatch_has_kernel_for_dispatch_key(name, 'Meta')
|
||||
):
|
||||
meta_lib.impl(op_overload, f)
|
||||
|
||||
# To handle allowing multiple aten_ops at once
|
||||
|
@ -392,7 +392,7 @@ def mse_loss_backward(
|
||||
return norm * (input - target) * grad_output
|
||||
|
||||
|
||||
@register_decomposition(aten.huber_loss, register_meta=True)
|
||||
@register_decomposition(aten.huber_loss)
|
||||
@pw_cast_for_opmath
|
||||
def huber_loss(
|
||||
self: Tensor,
|
||||
@ -1125,7 +1125,7 @@ def std_decomposition(
|
||||
# Questionable decompositions
|
||||
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
|
||||
# Note that this decomposition causes issues with in-place ops
|
||||
@register_decomposition(aten.detach)
|
||||
@register_decomposition(aten.detach, disable_meta=True)
|
||||
def detach_decomposition(x):
|
||||
return x
|
||||
|
||||
|
@ -247,7 +247,7 @@ def _make_elementwise_unary_reference(
|
||||
*,
|
||||
type_promotion_kind,
|
||||
aten_op=infer_aten_op,
|
||||
register_meta=False,
|
||||
disable_meta=False,
|
||||
extra_meta=None,
|
||||
) -> Callable:
|
||||
@out_wrapper
|
||||
@ -269,7 +269,7 @@ def _make_elementwise_unary_reference(
|
||||
if aten_op is infer_aten_op:
|
||||
aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
|
||||
if aten_op is not None:
|
||||
register_decomposition(aten_op, register_meta=register_meta)(_ref)
|
||||
register_decomposition(aten_op, disable_meta=disable_meta)(_ref)
|
||||
|
||||
return _ref
|
||||
|
||||
@ -373,7 +373,6 @@ isnan = _make_elementwise_unary_reference(
|
||||
_isnan,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
aten_op=torch.ops.aten.isnan, # prim/aten name mismatch
|
||||
register_meta=True,
|
||||
)
|
||||
|
||||
lgamma = _make_elementwise_unary_reference(
|
||||
@ -456,7 +455,7 @@ def _make_elementwise_binary_reference(
|
||||
has_out=True,
|
||||
supports_lhs_python_scalar=True,
|
||||
supports_rhs_python_scalar=True,
|
||||
register_meta=False,
|
||||
disable_meta=False,
|
||||
) -> Callable:
|
||||
@elementwise_type_promotion_wrapper(
|
||||
type_promoting_args=("a", "b"),
|
||||
@ -491,7 +490,7 @@ def _make_elementwise_binary_reference(
|
||||
if aten_op is infer_aten_op:
|
||||
aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
|
||||
if aten_op is not None:
|
||||
register_decomposition(aten_op, register_meta=register_meta)(_ref)
|
||||
register_decomposition(aten_op, disable_meta=disable_meta)(_ref)
|
||||
|
||||
return _ref
|
||||
|
||||
@ -717,7 +716,6 @@ logical_and = _make_elementwise_binary_reference(
|
||||
_logical_and,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
aten_op=torch.ops.aten.logical_and,
|
||||
register_meta=True,
|
||||
)
|
||||
|
||||
|
||||
@ -733,7 +731,6 @@ logical_or = _make_elementwise_binary_reference(
|
||||
_logical_or,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
aten_op=torch.ops.aten.logical_or,
|
||||
register_meta=True,
|
||||
)
|
||||
|
||||
# TODO: add docstring
|
||||
@ -836,7 +833,7 @@ true_divide = _make_elementwise_binary_reference(
|
||||
|
||||
# https://pytorch.org/docs/stable/generated/torch.where.html
|
||||
# TODO: implement alternate where
|
||||
@register_decomposition(torch.ops.aten.where, register_meta=True)
|
||||
@register_decomposition(torch.ops.aten.where)
|
||||
@out_wrapper
|
||||
@elementwise_type_promotion_wrapper(
|
||||
type_promoting_args=("a", "b"),
|
||||
@ -1090,7 +1087,7 @@ def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorL
|
||||
return prims.collapse(a, start_dim, end_dim + 1)
|
||||
|
||||
|
||||
@register_decomposition(torch.ops.aten.flip, register_meta=True)
|
||||
@register_decomposition(torch.ops.aten.flip)
|
||||
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
|
||||
if not isinstance(dims, tuple) and not isinstance(dims, list):
|
||||
raise ValueError("dims has to be a sequence of ints")
|
||||
|
@ -202,9 +202,6 @@ class Tensor(torch._C._TensorBase):
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(Tensor.storage, (self,), self)
|
||||
|
||||
if self.dtype not in torch.storage._dtype_to_storage_type_map():
|
||||
raise RuntimeError(f'unsupported Storage type: {self.dtype}')
|
||||
|
||||
return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
|
||||
|
||||
def _reduce_ex_internal(self, proto):
|
||||
|
@ -202,6 +202,17 @@ void initDispatchBindings(PyObject* module) {
|
||||
c10::Dispatcher::singleton().checkInvariants();
|
||||
});
|
||||
|
||||
m.def("_dispatch_has_kernel", [](const char* name) -> bool {
|
||||
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
|
||||
return static_cast<bool>(op);
|
||||
});
|
||||
|
||||
m.def("_dispatch_has_kernel_for_dispatch_key", [](const char* name, const char* dispatch) -> bool {
|
||||
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
|
||||
TORCH_CHECK(op, "operator ", name, " does not exist");
|
||||
return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
|
||||
});
|
||||
|
||||
m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
|
||||
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
|
||||
|
||||
|
@ -26,7 +26,7 @@ import collections
|
||||
import functools
|
||||
import types
|
||||
import warnings
|
||||
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator
|
||||
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator, Tuple
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
@ -42,6 +42,7 @@ __all__ = [
|
||||
"get_testing_overrides",
|
||||
"handle_torch_function",
|
||||
"has_torch_function",
|
||||
"resolve_name",
|
||||
"is_tensor_like",
|
||||
"is_tensor_method_or_property",
|
||||
"wrap_torch_function",
|
||||
@ -1556,36 +1557,32 @@ has_torch_function_variadic = _add_docstr(
|
||||
)
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
||||
"""List functions that are overridable via __torch_function__
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[Any, List[Callable]]
|
||||
A dictionary that maps namespaces that contain overridable functions
|
||||
to functions in that namespace that can be overridden.
|
||||
"""
|
||||
def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
|
||||
overridable_funcs = collections.defaultdict(list)
|
||||
index = {}
|
||||
tested_namespaces = [
|
||||
(torch, torch.__all__ + dir(torch._C._VariableFunctions)),
|
||||
(torch.functional, torch.functional.__all__),
|
||||
(torch.nn.functional, dir(torch.nn.functional)),
|
||||
(torch.nn.init, dir(torch.nn.init)),
|
||||
(torch.Tensor, dir(torch.Tensor)),
|
||||
(torch.linalg, dir(torch.linalg)),
|
||||
(torch.fft, dir(torch.fft)),
|
||||
(torch.special, dir(torch.special)),
|
||||
("torch", torch, torch.__all__ + dir(torch._C._VariableFunctions)),
|
||||
("torch.functional", torch.functional, torch.functional.__all__),
|
||||
("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
|
||||
("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
|
||||
("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
|
||||
("torch.linalg", torch.linalg, dir(torch.linalg)),
|
||||
("torch.fft", torch.fft, dir(torch.fft)),
|
||||
("torch.special", torch.special, dir(torch.special)),
|
||||
]
|
||||
for namespace, ns_funcs in tested_namespaces:
|
||||
for namespace_str, namespace, ns_funcs in tested_namespaces:
|
||||
for func_name in ns_funcs:
|
||||
ignore = False
|
||||
# ignore private functions or functions that are deleted in torch.__init__
|
||||
if namespace is not torch.Tensor:
|
||||
if func_name.startswith('_'):
|
||||
if func_name.startswith('__'):
|
||||
continue
|
||||
elif func_name.startswith('_'):
|
||||
ignore = True
|
||||
elif func_name.endswith('_'):
|
||||
continue
|
||||
ignore = True
|
||||
elif not func_name[0].islower():
|
||||
continue
|
||||
ignore = True
|
||||
elif func_name == 'unique_dim':
|
||||
continue
|
||||
else:
|
||||
@ -1605,6 +1602,10 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
||||
continue
|
||||
|
||||
if not callable(func) and hasattr(func, "__get__"):
|
||||
index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
|
||||
index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
|
||||
if ignore:
|
||||
continue
|
||||
if func.__get__ in get_ignored_functions():
|
||||
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
|
||||
"but still has an explicit override")
|
||||
@ -1617,6 +1618,11 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
||||
if not callable(func):
|
||||
continue
|
||||
|
||||
index[func] = f"{namespace_str}.{func_name}"
|
||||
|
||||
if ignore:
|
||||
continue
|
||||
|
||||
# cannot be overriden by __torch_function__
|
||||
if func in get_ignored_functions():
|
||||
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
|
||||
@ -1624,7 +1630,37 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
||||
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
|
||||
continue
|
||||
overridable_funcs[namespace].append(func)
|
||||
return overridable_funcs
|
||||
return overridable_funcs, index
|
||||
|
||||
def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
||||
"""List functions that are overridable via __torch_function__
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[Any, List[Callable]]
|
||||
A dictionary that maps namespaces that contain overridable functions
|
||||
to functions in that namespace that can be overridden.
|
||||
"""
|
||||
return _get_overridable_functions()[0]
|
||||
|
||||
def resolve_name(f):
|
||||
"""Get a human readable string name for a function passed to
|
||||
__torch_function__
|
||||
|
||||
Arguments
|
||||
---------
|
||||
callable : Callable
|
||||
Function to resolve the name of.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Name of the function; if eval'ed it should give back the input
|
||||
function.
|
||||
"""
|
||||
if isinstance(f, torch._ops.OpOverload):
|
||||
return str(f)
|
||||
return _get_overridable_functions()[1].get(f)
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_tensor_methods() -> Set[Callable]:
|
||||
@ -1782,6 +1818,10 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def push(cls, *args, **kwargs):
|
||||
return push_torch_function_mode(functools.partial(cls, *args, **kwargs))
|
||||
|
||||
|
||||
class BaseTorchFunctionMode(TorchFunctionMode):
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
|
@ -10026,7 +10026,6 @@ op_db: List[OpInfo] = [
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/50747
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', dtypes=(torch.bool,)),
|
||||
),
|
||||
sample_inputs_func=sample_inputs_addr,
|
||||
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
|
||||
@ -11441,8 +11440,6 @@ op_db: List[OpInfo] = [
|
||||
skips=(
|
||||
# Skip since real and imag don't have out variants.
|
||||
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta',
|
||||
dtypes=(torch.complex32,)),
|
||||
)),
|
||||
OpInfo('gradient',
|
||||
dtypes=floating_and_complex_types_and(torch.int8, torch.int16,
|
||||
@ -12036,6 +12033,7 @@ op_db: List[OpInfo] = [
|
||||
dtypes=floating_and_complex_types(),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
skips=(skipCPUIfNoLapack,),
|
||||
sample_inputs_func=sample_inputs_lu_unpack),
|
||||
OpInfo('lu',
|
||||
op=torch.lu,
|
||||
@ -12069,7 +12067,7 @@ op_db: List[OpInfo] = [
|
||||
# See https://github.com/pytorch/pytorch/issues/66357
|
||||
check_batched_forward_grad=False,
|
||||
sample_inputs_func=sample_inputs_lu_solve,
|
||||
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
|
||||
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
|
||||
skips=(
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
|
||||
device_type='mps', dtypes=[torch.float32]),
|
||||
@ -12595,7 +12593,6 @@ op_db: List[OpInfo] = [
|
||||
skips=(
|
||||
# AssertionError: Resizing an out= argument with no elements threw a resize warning!
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cpu'),
|
||||
)),
|
||||
OpInfo('as_strided',
|
||||
op=lambda x, size, stride, storage_offset=0:
|
||||
@ -13334,9 +13331,6 @@ op_db: List[OpInfo] = [
|
||||
skips=(
|
||||
# Pre-existing condition; Needs to be fixed
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cpu'),
|
||||
# RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16'
|
||||
DecorateInfo(unittest.skip("Works on some configs"), 'TestMeta',
|
||||
'test_meta', dtypes=(torch.bfloat16,)),
|
||||
DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
|
||||
'test_nnc_correctness', dtypes=(torch.bfloat16,)),
|
||||
DecorateInfo(unittest.skip("Works on some conifgs"), 'TestCudaFuserOpInfo',
|
||||
@ -14382,8 +14376,6 @@ op_db: List[OpInfo] = [
|
||||
skips=(
|
||||
# Skip since real and imag don't have out variants.
|
||||
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta',
|
||||
dtypes=(torch.complex32,)),
|
||||
)),
|
||||
OpInfo('roll',
|
||||
ref=np.roll,
|
||||
@ -14896,7 +14888,7 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
|
||||
decorators=[skipCUDAIfNoMagma],
|
||||
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
|
||||
skips=(
|
||||
# AssertionError: Scalars are not equal!
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
|
||||
@ -15119,12 +15111,7 @@ op_db: List[OpInfo] = [
|
||||
ref=np.isfinite,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
|
||||
supports_out=False,
|
||||
supports_autograd=False,
|
||||
skips=(
|
||||
# NotImplementedError:
|
||||
# Could not run 'aten::view_as_real' with arguments from the 'Meta' backend.
|
||||
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)),
|
||||
)),
|
||||
supports_autograd=False),
|
||||
UnaryUfuncInfo('isinf',
|
||||
ref=np.isinf,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
|
||||
@ -15133,8 +15120,6 @@ op_db: List[OpInfo] = [
|
||||
supports_sparse_csr=True,
|
||||
supports_autograd=False,
|
||||
skips=(
|
||||
# Could not run 'aten::view_as_real' with arguments from the 'Meta' backend.
|
||||
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)),
|
||||
# "nonzero_count_cpu" not implemented for 'ComplexHalf'
|
||||
# "nonzero_cuda" not implemented for 'ComplexHalf'
|
||||
DecorateInfo(unittest.expectedFailure, "TestSparseCSR",
|
||||
@ -15160,12 +15145,7 @@ op_db: List[OpInfo] = [
|
||||
ref=np.isreal,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
|
||||
supports_out=False,
|
||||
supports_autograd=False,
|
||||
skips=(
|
||||
# NotImplementedError:
|
||||
# Could not run 'aten::view_as_real' with arguments from the 'Meta' backend.
|
||||
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta", dtypes=(torch.chalf,)),
|
||||
)),
|
||||
supports_autograd=False),
|
||||
UnaryUfuncInfo('isnan',
|
||||
ref=np.isnan,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
|
||||
@ -15198,6 +15178,7 @@ op_db: List[OpInfo] = [
|
||||
dtypes=floating_and_complex_types(),
|
||||
sample_inputs_func=sample_inputs_linalg_solve_triangular,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
skips=(skipCPUIfNoLapack,),
|
||||
# linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
|
||||
supports_forward_ad=True),
|
||||
OpInfo('linalg.matrix_rank',
|
||||
@ -15242,6 +15223,7 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
sample_inputs_func=sample_inputs_linalg_pinv,
|
||||
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
|
||||
skips=(
|
||||
# errors with "leaked XXXX bytes CUDA memory on device 0"
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),)
|
||||
@ -15344,9 +15326,6 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
|
||||
# stride mismatch
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda',
|
||||
dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
|
||||
device_type='mps', dtypes=[torch.float32]),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
|
||||
@ -15371,9 +15350,6 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
|
||||
# stride mismatch
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda',
|
||||
dtypes=(torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
|
||||
device_type='mps', dtypes=[torch.float32]),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
|
||||
@ -15418,8 +15394,6 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
|
||||
# stride mismatch
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM),
|
||||
)),
|
||||
OpInfo('pca_lowrank',
|
||||
op=lambda *args, **kwargs: wrapper_set_seed(
|
||||
@ -15444,8 +15418,6 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
|
||||
# stride mismatch
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta', device_type='cuda', active_if=not TEST_WITH_ROCM),
|
||||
)),
|
||||
BinaryUfuncInfo('polar',
|
||||
dtypes=floating_types(),
|
||||
@ -17423,7 +17395,7 @@ op_db: List[OpInfo] = [
|
||||
sample_inputs_func=sample_inputs_tensorsolve,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
|
||||
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagma],
|
||||
),
|
||||
OpInfo(
|
||||
"nn.functional.mse_loss",
|
||||
|
@ -147,6 +147,10 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def push(cls, *args, **kwargs):
|
||||
return push_torch_dispatch_mode(functools.partial(cls, *args, **kwargs))
|
||||
|
||||
|
||||
class BaseTorchDispatchMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
|
Reference in New Issue
Block a user