Change default torch_function behavior to be disabled when torch_dispatch is defined (take 2) (#120632)

This does not introduce a new test; it is covered by verifying that all the existing subclasses still behave as before now that they no longer explicitly disable torch_function.
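For illustration (not part of the original commit message), here is a minimal sketch of what the new default means for subclass authors; the class name is hypothetical:

import torch

class QuietTensor(torch.Tensor):
    # No explicit __torch_function__ here: because __torch_dispatch__ is
    # defined, the default __torch_function__ is disabled automatically at
    # class creation time.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)

# The explicit boilerplate that the diffs below delete is now installed for us.
print(QuietTensor.__torch_function__ is torch._C._disabled_torch_function_impl)  # expected: True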

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
albanD
2024-03-09 01:08:32 +00:00
committed by PyTorch MergeBot
parent ca9678405a
commit 6791b0c09e
15 changed files with 49 additions and 48 deletions

View File

@@ -756,8 +756,6 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
                raise NotImplementedError()

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
        class foo_autograd_fn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):

View File

@@ -60,8 +60,6 @@ class TorchDispatchTensor(torch.Tensor):
        t.elem = elem
        return t

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

View File

@@ -38,6 +38,12 @@ class Int16Tensor(torch.Tensor):
        out = tree_map(wrap, out)
        return out

+    # This most likely should be removed (and thus use the disabled impl)
+    # but the test below fail under Dynamo in that case.
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        return super().__torch_function__(func, types, args, kwargs)
+
    def __repr__(self) -> str:
        with no_dispatch():
            t16 = self.view(torch.int16)

View File

@@ -9358,8 +9358,6 @@ class TestAutogradForwardMode(TestCase):
            def __new__(cls, data=None):
                return torch.Tensor._make_subclass(cls, data)

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func.overloadpacket == torch.ops.aten.alias:

View File

@@ -864,8 +864,6 @@ $6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')
    def test_new_ones(self) -> None:
        class MyTensor(torch.Tensor):
-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)
@@ -874,8 +872,6 @@ $6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')
    def test_like(self) -> None:
        class MyTensor(torch.Tensor):
-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)
@@ -1055,7 +1051,6 @@ def forward(self, x_a_1, x_b_1, y_1):
        called_funcs = []

        class MyTensor(torch.Tensor):
-            __torch_function__ = torch._C._disabled_torch_function_impl
            elem: torch.Tensor
            __slots__ = ['elem']
@@ -1365,8 +1360,6 @@ $3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
                return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
        a = SubTensor(torch.randn(2))
        with PoliteMode() as mode:
            a.abs()
@@ -1601,8 +1594,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
@@ -1621,8 +1612,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
@@ -1640,8 +1629,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
@@ -1662,8 +1649,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
        called = 0

        class SubTensor(torch.Tensor):
-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                nonlocal called
@@ -1698,8 +1683,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
                r = torch.Tensor._make_subclass(cls, elem)
                return r

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with no_dispatch():

View File

@@ -53,8 +53,6 @@ class IncorrectAliasTensor(torch.Tensor):
    __slots__ = ['elem']

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any

View File

@@ -235,8 +235,6 @@ class TestSubclass(TestCase):
            def __init__(self, t) -> None:
                self.tensor: torch.Tensor = t

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

View File

@@ -670,8 +670,6 @@ class FakeTensor(torch.Tensor):
            out.append(s)
        return out

-    __torch_function__ = torch._C._disabled_torch_function_impl
-

@dataclass(frozen=True)
class TensorMetadata:

View File

@@ -122,11 +122,6 @@ class FunctionalTensor(torch.Tensor):
        out.elem = elem
        return out

-    # Need to disable default torch_function. Why?
-    # Default torch_function will always wrap outputs into a subclass if they aren't already a subclass.
-    # We actually.. don't want to do this sometimes, see Note [FunctionalTensorMode inputs are sometimes plain tensors]
-    __torch_function__ = torch._C._disabled_torch_function_impl
-
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        unrecognized_types = [
            t

View File

@@ -2242,11 +2242,51 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
  ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
  ((PyTypeObject*)cls)->tp_traverse =
      (traverseproc)THPVariable_subclass_traverse;
+
+  // Don't do anything for the base Tensor class
+  if (!THPVariableClass) {
+    return 0;
+  }
+
+  // If the user provided a torch_dispatch implementation, disable
+  // torch_function.
+  py::object torch_dispatch_impl = py::reinterpret_steal<py::object>(
+      PyObject_GetAttrString(cls, "__torch_dispatch__"));
+  py::object torch_dispatch_default = py::reinterpret_steal<py::object>(
+      PyObject_GetAttrString(THPVariableClass, "__torch_dispatch__"));
+  if (torch_dispatch_impl.ptr() != torch_dispatch_default.ptr()) {
+    py::object torch_function_impl = py::reinterpret_steal<py::object>(
+        PyObject_GetAttrString(cls, "__torch_function__"));
+    // This will only fail if the user subclasses _TensorBase directly.
+    // Ignore the error here to let the class __init__ code fail with a nice
+    // error message.
+    if (!torch_function_impl) {
+      PyErr_Clear();
+      return 0;
+    }
+
+    py::object torch_function_default_bound = py::reinterpret_steal<py::object>(
+        PyObject_GetAttrString(THPVariableClass, "__torch_function__"));
+    // Since our __torch_function__ is a classmethod, we need to "unbound" the
+    // method to get the raw function
+    py::object torch_function_default = py::reinterpret_steal<py::object>(
+        PyObject_GetAttrString(torch_function_default_bound.ptr(), "__func__"));
+
+    // User-defined __torch_function__ might not be a classmethod
+    if (PyObject_HasAttrString(torch_function_impl.ptr(), "__func__")) {
+      torch_function_impl = py::reinterpret_steal<py::object>(
+          PyObject_GetAttrString(torch_function_impl.ptr(), "__func__"));
+    }
+    if (torch_function_impl.ptr() == torch_function_default.ptr()) {
+      PyObject_SetAttrString(
+          cls, "__torch_function__", torch::disabled_torch_function_impl());
+    }
+  }
+
  return 0;
}

-namespace torch {
-namespace autograd {
+namespace torch::autograd {

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
extern PyMethodDef variable_methods[];
@@ -2268,8 +2308,7 @@ void initTensorImplConversion(PyObject* module) {
        return t->getIntrusivePtr().get();
      });
}
-} // namespace autograd
-} // namespace torch
+} // namespace torch::autograd

bool THPVariable_initModule(PyObject* module) {
  THPVariableMetaType.tp_base = &PyType_Type;
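For readers who prefer Python, here is a rough rendering of the metaclass hook above (illustrative only; the helper name is made up, and the real hook compares against torch.Tensor's default implementations via __func__ rather than inspecting __dict__):

import torch

def _maybe_disable_torch_function(cls):
    # Rough sketch of the check performed when a new Tensor subclass is created.
    if cls is torch.Tensor:
        # Don't do anything for the base Tensor class.
        return
    defines_dispatch = "__torch_dispatch__" in cls.__dict__
    defines_function = "__torch_function__" in cls.__dict__
    # The subclass provides its own __torch_dispatch__ but inherited the
    # default __torch_function__: install the disabled implementation for it.
    if defines_dispatch and not defines_function:
        cls.__torch_function__ = torch._C._disabled_torch_function_impl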

View File

@@ -550,8 +550,6 @@ class AsyncCollectiveTensor(torch.Tensor):
    __slots__ = ["elem", "completed"]

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]

View File

@@ -278,8 +278,6 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
            stride=outer_stride,
        )

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
    @classmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.

View File

@@ -64,9 +64,6 @@ class DiagTensorBelow(WrapperTensor):
    handled_ops = {}

-    # We disable torch function here to avoid any unwanted wrapping of the output
-    __torch_function__ = torch._C._disabled_torch_function_impl
-
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):

View File

@@ -117,7 +117,6 @@ def generate_cct_and_mode(autograd_view_consistency=True):
        elem: torch.Tensor
        __slots__ = ['elem']

-        __torch_function__ = torch._C._disabled_torch_function_impl
        @staticmethod
        def __new__(cls, elem, mode, *args, **kwargs):

View File

@@ -50,8 +50,6 @@ class LoggingTensor(torch.Tensor):
    context = contextlib.nullcontext

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (LoggingTensor) shouldn't hold any