diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py
index 3927673177ac..d23fec607afa 100644
--- a/test/dynamo/test_autograd_function.py
+++ b/test/dynamo/test_autograd_function.py
@@ -756,8 +756,6 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):

                 raise NotImplementedError()

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
         class foo_autograd_fn(torch.autograd.Function):
             @staticmethod
             def forward(ctx, x):
diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py
index 8301fae214f3..d15218b3922b 100644
--- a/test/profiler/test_profiler_tree.py
+++ b/test/profiler/test_profiler_tree.py
@@ -60,8 +60,6 @@ class TorchDispatchTensor(torch.Tensor):
         t.elem = elem
         return t

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
diff --git a/test/quantization/core/experimental/test_bits.py b/test/quantization/core/experimental/test_bits.py
index bf7f3812744b..dfba754590d8 100644
--- a/test/quantization/core/experimental/test_bits.py
+++ b/test/quantization/core/experimental/test_bits.py
@@ -38,6 +38,12 @@ class Int16Tensor(torch.Tensor):
         out = tree_map(wrap, out)
         return out

+    # This most likely should be removed (and thus use the disabled impl)
+    # but the test below fails under Dynamo in that case.
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        return super().__torch_function__(func, types, args, kwargs)
+
     def __repr__(self) -> str:
         with no_dispatch():
             t16 = self.view(torch.int16)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 326a493758b4..90aa883dbd91 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -9358,8 +9358,6 @@ class TestAutogradForwardMode(TestCase):
             def __new__(cls, data=None):
                 return torch.Tensor._make_subclass(cls, data)

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 if func.overloadpacket == torch.ops.aten.alias:
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index c04f724227b9..04c263591fd1 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -864,8 +864,6 @@ $6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')

     def test_new_ones(self) -> None:
         class MyTensor(torch.Tensor):
-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 return MyTensor(3)
@@ -874,8 +872,6 @@ $6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')

     def test_like(self) -> None:
         class MyTensor(torch.Tensor):
-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 return MyTensor(3)
@@ -1055,7 +1051,6 @@ def forward(self, x_a_1, x_b_1, y_1):
         called_funcs = []

         class MyTensor(torch.Tensor):
-            __torch_function__ = torch._C._disabled_torch_function_impl
             elem: torch.Tensor
             __slots__ = ['elem']

@@ -1365,8 +1360,6 @@ $3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
                 return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
         a = SubTensor(torch.randn(2))
         with PoliteMode() as mode:
             a.abs()
@@ -1601,8 +1594,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
             def __new__(cls, elem):
                 return torch.Tensor._make_subclass(cls, elem)

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 called.append(func)
@@ -1621,8 +1612,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
             def __new__(cls, elem):
                 return torch.Tensor._make_subclass(cls, elem)

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 called.append(func)
@@ -1640,8 +1629,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
             def __new__(cls, elem):
                 return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 called.append(func)
@@ -1662,8 +1649,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
         called = 0

         class SubTensor(torch.Tensor):
-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 nonlocal called
@@ -1698,8 +1683,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
                 r = torch.Tensor._make_subclass(cls, elem)
                 return r

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 with no_dispatch():
diff --git a/test/test_schema_check.py b/test/test_schema_check.py
index 7233e5f54a9d..831ba9b89250 100644
--- a/test/test_schema_check.py
+++ b/test/test_schema_check.py
@@ -53,8 +53,6 @@ class IncorrectAliasTensor(torch.Tensor):

     __slots__ = ['elem']

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
     @staticmethod
     def __new__(cls, elem, *args, **kwargs):
         # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any
diff --git a/test/test_subclass.py b/test/test_subclass.py
index 01253955ad9c..869982c2830a 100644
--- a/test/test_subclass.py
+++ b/test/test_subclass.py
@@ -235,8 +235,6 @@ class TestSubclass(TestCase):
             def __init__(self, t) -> None:
                 self.tensor: torch.Tensor = t

-            __torch_function__ = torch._C._disabled_torch_function_impl
-
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 2f525d6454d2..08007dcfdd11 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -670,8 +670,6 @@ class FakeTensor(torch.Tensor):
                 out.append(s)
         return out

-    __torch_function__ = torch._C._disabled_torch_function_impl
-

 @dataclass(frozen=True)
 class TensorMetadata:
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index 3152e99fae66..ce346b81401a 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -122,11 +122,6 @@ class FunctionalTensor(torch.Tensor):
         out.elem = elem
         return out

-    # Need to disable default torch_function. Why?
-    # Default torch_function will always wrap outputs into a subclass if they aren't already a subclass.
-    # We actually.. don't want to do this sometimes, see Note [FunctionalTensorMode inputs are sometimes plain tensors]
-    __torch_function__ = torch._C._disabled_torch_function_impl
-
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         unrecognized_types = [
             t
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 27712b91fc58..0b8f36135c8d 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -2242,11 +2242,51 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
   ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
   ((PyTypeObject*)cls)->tp_traverse =
       (traverseproc)THPVariable_subclass_traverse;
+
+  // Don't do anything for the base Tensor class
+  if (!THPVariableClass) {
+    return 0;
+  }
+
+  // If the user provided a torch_dispatch implementation, disable
+  // torch_function.
+  py::object torch_dispatch_impl = py::reinterpret_steal<py::object>(
+      PyObject_GetAttrString(cls, "__torch_dispatch__"));
+  py::object torch_dispatch_default = py::reinterpret_steal<py::object>(
+      PyObject_GetAttrString(THPVariableClass, "__torch_dispatch__"));
+  if (torch_dispatch_impl.ptr() != torch_dispatch_default.ptr()) {
+    py::object torch_function_impl = py::reinterpret_steal<py::object>(
+        PyObject_GetAttrString(cls, "__torch_function__"));
+    // This will only fail if the user subclasses _TensorBase directly.
+    // Ignore the error here to let the class __init__ code fail with a nice
+    // error message.
+    if (!torch_function_impl) {
+      PyErr_Clear();
+      return 0;
+    }
+    py::object torch_function_default_bound = py::reinterpret_steal<py::object>(
+        PyObject_GetAttrString(THPVariableClass, "__torch_function__"));
+
+    // Since our __torch_function__ is a classmethod, we need to "unbind" the
+    // method to get the raw function
+    py::object torch_function_default = py::reinterpret_steal<py::object>(
+        PyObject_GetAttrString(torch_function_default_bound.ptr(), "__func__"));
+
+    // User-defined __torch_function__ might not be a classmethod
+    if (PyObject_HasAttrString(torch_function_impl.ptr(), "__func__")) {
+      torch_function_impl = py::reinterpret_steal<py::object>(
+          PyObject_GetAttrString(torch_function_impl.ptr(), "__func__"));
+    }
+    if (torch_function_impl.ptr() == torch_function_default.ptr()) {
+      PyObject_SetAttrString(
+          cls, "__torch_function__", torch::disabled_torch_function_impl());
+    }
+  }
+
   return 0;
 }

-namespace torch {
-namespace autograd {
+namespace torch::autograd {

 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
 extern PyMethodDef variable_methods[];
@@ -2268,8 +2308,7 @@ void initTensorImplConversion(PyObject* module) {
     return t->getIntrusivePtr().get();
   });
 }
-} // namespace autograd
-} // namespace torch
+} // namespace torch::autograd

 bool THPVariable_initModule(PyObject* module) {
   THPVariableMetaType.tp_base = &PyType_Type;
diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py
index 2cd48a06a6c1..2c802f2bcdeb 100644
--- a/torch/distributed/_functional_collectives.py
+++ b/torch/distributed/_functional_collectives.py
@@ -550,8 +550,6 @@ class AsyncCollectiveTensor(torch.Tensor):

     __slots__ = ["elem", "completed"]

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
     @staticmethod
     def __new__(cls, elem: torch.Tensor):
         r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py
index 1e6cef1c98f1..4336e4627810 100644
--- a/torch/distributed/_tensor/api.py
+++ b/torch/distributed/_tensor/api.py
@@ -278,8 +278,6 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
             stride=outer_stride,
         )

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
     @classmethod
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
diff --git a/torch/testing/_internal/common_subclass.py b/torch/testing/_internal/common_subclass.py
index 5230d9181309..f6a8ed065cb8 100644
--- a/torch/testing/_internal/common_subclass.py
+++ b/torch/testing/_internal/common_subclass.py
@@ -64,9 +64,6 @@ class DiagTensorBelow(WrapperTensor):

     handled_ops = {}

-    # We disable torch function here to avoid any unwanted wrapping of the output
-    __torch_function__ = torch._C._disabled_torch_function_impl
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         if not all(issubclass(cls, t) for t in types):
diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py
index d9976cd261fb..b3c3bd4a130e 100644
--- a/torch/testing/_internal/composite_compliance.py
+++ b/torch/testing/_internal/composite_compliance.py
@@ -117,7 +117,6 @@ def generate_cct_and_mode(autograd_view_consistency=True):
         elem: torch.Tensor
         __slots__ = ['elem']

-        __torch_function__ = torch._C._disabled_torch_function_impl

         @staticmethod
         def __new__(cls, elem, mode, *args, **kwargs):
diff --git a/torch/testing/_internal/logging_tensor.py b/torch/testing/_internal/logging_tensor.py
index c3ce8648d955..dedb83343e5d 100644
--- a/torch/testing/_internal/logging_tensor.py
+++ b/torch/testing/_internal/logging_tensor.py
@@ -50,8 +50,6 @@ class LoggingTensor(torch.Tensor):

     context = contextlib.nullcontext

-    __torch_function__ = torch._C._disabled_torch_function_impl
-
     @staticmethod
     def __new__(cls, elem, *args, **kwargs):
         # The wrapping tensor (LoggingTensor) shouldn't hold any
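
Note (not part of the patch): the new THPVariableMetaType_init logic above means that a Tensor subclass which overrides __torch_dispatch__ but leaves __torch_function__ untouched now gets the disabled __torch_function__ automatically, which is why the explicit __torch_function__ = torch._C._disabled_torch_function_impl assignments are deleted throughout. A minimal sketch of the expected behavior; the subclass name is illustrative, not from the patch:

    import torch

    class DispatchOnly(torch.Tensor):
        # Overrides __torch_dispatch__ but not __torch_function__, so the
        # Tensor metaclass is expected to swap in the disabled
        # __torch_function__ implementation when this class is created.
        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            return func(*args, **kwargs)

    # Expected to print True after this change: the inherited default
    # __torch_function__ has been replaced with the disabled implementation.
    print(DispatchOnly.__torch_function__ is torch._C._disabled_torch_function_impl)

Subclasses that define their own __torch_function__ (like Int16Tensor in test_bits.py above) keep it; only the inherited default is replaced.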