Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Change default torch_function behavior to be disabled when torch_dispatch is defined (take 2) (#120632)
This does not introduce a new test; it is covered by checking that all the classes we already have still behave as before now that they no longer explicitly disable __torch_function__.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
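In practice, the change means a Tensor subclass that defines __torch_dispatch__ no longer has to set __torch_function__ = torch._C._disabled_torch_function_impl by hand, which is exactly the line removed throughout the hunks below. A minimal wrapper-subclass sketch of the new default (the class name and the passthrough dispatch body are illustrative, not code from this PR):

import torch
from torch.utils._pytree import tree_map

class PassthroughTensor(torch.Tensor):
    # Before this commit, the class also needed:
    #     __torch_function__ = torch._C._disabled_torch_function_impl
    # With the new default, defining __torch_dispatch__ is enough to get the
    # disabled __torch_function__ behavior automatically.
    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device,
            requires_grad=elem.requires_grad,
        )
        r.elem = elem  # plain inner tensor held by the wrapper
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the plain inner tensors and run the op on them.
        def unwrap(t):
            return t.elem if isinstance(t, PassthroughTensor) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

x = PassthroughTensor(torch.randn(3))
y = x + 1  # routed through __torch_dispatch__; the output is a plain Tensor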
@@ -756,8 +756,6 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
raise NotImplementedError()
__torch_function__ = torch._C._disabled_torch_function_impl
class foo_autograd_fn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):

@@ -60,8 +60,6 @@ class TorchDispatchTensor(torch.Tensor):
t.elem = elem
return t
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

@@ -38,6 +38,12 @@ class Int16Tensor(torch.Tensor):
out = tree_map(wrap, out)
return out
# This most likely should be removed (and thus use the disabled impl)
# but the test below fail under Dynamo in that case.
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
return super().__torch_function__(func, types, args, kwargs)
def __repr__(self) -> str:
with no_dispatch():
t16 = self.view(torch.int16)

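For context on the hunk above: the default __torch_function__ re-wraps op outputs into the calling subclass, and Int16Tensor still relies on that, which is why it now spells out a forwarding __torch_function__ instead of inheriting the new disabled default. A tiny sketch of the default wrapping behavior on its own (hypothetical class name, no __torch_dispatch__ involved):

import torch

class Wrapped(torch.Tensor):
    pass  # no __torch_dispatch__, so the default __torch_function__ stays active

a = torch.randn(2).as_subclass(Wrapped)
# The default __torch_function__ wraps outputs back into the subclass:
print(type(a + 1))  # expected: Wrapped (the subclass is preserved)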
@@ -9358,8 +9358,6 @@ class TestAutogradForwardMode(TestCase):
def __new__(cls, data=None):
return torch.Tensor._make_subclass(cls, data)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func.overloadpacket == torch.ops.aten.alias:

@@ -864,8 +864,6 @@ $6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')
def test_new_ones(self) -> None:
class MyTensor(torch.Tensor):
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return MyTensor(3)

@@ -874,8 +872,6 @@ $6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')
def test_like(self) -> None:
class MyTensor(torch.Tensor):
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return MyTensor(3)

@@ -1055,7 +1051,6 @@ def forward(self, x_a_1, x_b_1, y_1):
called_funcs = []
class MyTensor(torch.Tensor):
__torch_function__ = torch._C._disabled_torch_function_impl
elem: torch.Tensor
__slots__ = ['elem']

@@ -1365,8 +1360,6 @@ $3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
__torch_function__ = torch._C._disabled_torch_function_impl
a = SubTensor(torch.randn(2))
with PoliteMode() as mode:
a.abs()

@@ -1601,8 +1594,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)

@@ -1621,8 +1612,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)

@@ -1640,8 +1629,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)

@@ -1662,8 +1649,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
called = 0
class SubTensor(torch.Tensor):
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
nonlocal called

@@ -1698,8 +1683,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
r = torch.Tensor._make_subclass(cls, elem)
return r
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with no_dispatch():

@@ -53,8 +53,6 @@ class IncorrectAliasTensor(torch.Tensor):
__slots__ = ['elem']
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def __new__(cls, elem, *args, **kwargs):
# The wrapping tensor (IncorrectAliasTensor) shouldn't hold any

@@ -235,8 +235,6 @@ class TestSubclass(TestCase):
def __init__(self, t) -> None:
self.tensor: torch.Tensor = t
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

@@ -670,8 +670,6 @@ class FakeTensor(torch.Tensor):
out.append(s)
return out
__torch_function__ = torch._C._disabled_torch_function_impl
@dataclass(frozen=True)
class TensorMetadata:

@@ -122,11 +122,6 @@ class FunctionalTensor(torch.Tensor):
out.elem = elem
return out
# Need to disable default torch_function. Why?
# Default torch_function will always wrap outputs into a subclass if they aren't already a subclass.
# We actually.. don't want to do this sometimes, see Note [FunctionalTensorMode inputs are sometimes plain tensors]
__torch_function__ = torch._C._disabled_torch_function_impl
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
unrecognized_types = [
t

@@ -2242,11 +2242,51 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
  ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
  ((PyTypeObject*)cls)->tp_traverse =
      (traverseproc)THPVariable_subclass_traverse;

  // Don't do anything for the base Tensor class
  if (!THPVariableClass) {
    return 0;
  }

  // If the user provided a torch_dispatch implementation, disable
  // torch_function.
  py::object torch_dispatch_impl = py::reinterpret_steal<py::object>(
      PyObject_GetAttrString(cls, "__torch_dispatch__"));
  py::object torch_dispatch_default = py::reinterpret_steal<py::object>(
      PyObject_GetAttrString(THPVariableClass, "__torch_dispatch__"));
  if (torch_dispatch_impl.ptr() != torch_dispatch_default.ptr()) {
    py::object torch_function_impl = py::reinterpret_steal<py::object>(
        PyObject_GetAttrString(cls, "__torch_function__"));
    // This will only fail if the user subclasses _TensorBase directly.
    // Ignore the error here to let the class __init__ code fail with a nice
    // error message.
    if (!torch_function_impl) {
      PyErr_Clear();
      return 0;
    }
    py::object torch_function_default_bound = py::reinterpret_steal<py::object>(
        PyObject_GetAttrString(THPVariableClass, "__torch_function__"));

    // Since our __torch_function__ is a classmethod, we need to "unbound" the
    // method to get the raw function
    py::object torch_function_default = py::reinterpret_steal<py::object>(
        PyObject_GetAttrString(torch_function_default_bound.ptr(), "__func__"));

    // User-defined __torch_function__ might not be a classmethod
    if (PyObject_HasAttrString(torch_function_impl.ptr(), "__func__")) {
      torch_function_impl = py::reinterpret_steal<py::object>(
          PyObject_GetAttrString(torch_function_impl.ptr(), "__func__"));
    }
    if (torch_function_impl.ptr() == torch_function_default.ptr()) {
      PyObject_SetAttrString(
          cls, "__torch_function__", torch::disabled_torch_function_impl());
    }
  }

  return 0;
}

namespace torch {
namespace autograd {
namespace torch::autograd {

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
extern PyMethodDef variable_methods[];

@@ -2268,8 +2308,7 @@ void initTensorImplConversion(PyObject* module) {
return t->getIntrusivePtr().get();
});
}
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

bool THPVariable_initModule(PyObject* module) {
THPVariableMetaType.tp_base = &PyType_Type;

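The effect of the C++ hook added to THPVariableMetaType_init above should be observable from Python: a subclass that defines __torch_dispatch__ but leaves __torch_function__ untouched ends up with the disabled implementation, while a subclass that provides its own __torch_function__ keeps it. A minimal sketch with hypothetical class names (the dispatch bodies are never exercised here):

import torch

class DispatchOnly(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError  # not exercised in this sketch

class DispatchAndFunction(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, types, args, kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError  # not exercised in this sketch

# Expected under this change: True for DispatchOnly (the metaclass hook swapped
# in the disabled impl), False for DispatchAndFunction (its own impl is kept).
print(DispatchOnly.__torch_function__ is torch._C._disabled_torch_function_impl)
print(DispatchAndFunction.__torch_function__ is torch._C._disabled_torch_function_impl)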
@@ -550,8 +550,6 @@ class AsyncCollectiveTensor(torch.Tensor):
__slots__ = ["elem", "completed"]
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def __new__(cls, elem: torch.Tensor):
r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]

@@ -278,8 +278,6 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
stride=outer_stride,
)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.

@@ -64,9 +64,6 @@ class DiagTensorBelow(WrapperTensor):
handled_ops = {}
# We disable torch function here to avoid any unwanted wrapping of the output
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if not all(issubclass(cls, t) for t in types):

@@ -117,7 +117,6 @@ def generate_cct_and_mode(autograd_view_consistency=True):
elem: torch.Tensor
__slots__ = ['elem']
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def __new__(cls, elem, mode, *args, **kwargs):

@@ -50,8 +50,6 @@ class LoggingTensor(torch.Tensor):
context = contextlib.nullcontext
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def __new__(cls, elem, *args, **kwargs):
# The wrapping tensor (LoggingTensor) shouldn't hold any