Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Update impl_abstract_pystub to be less boilerplatey (#112851)"
This reverts commit 6ae4e3a8d249a96d9a8bbfba389d0509783e11e1. Reverted https://github.com/pytorch/pytorch/pull/112851 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/112851#issuecomment-1799539354))
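For context: the API shape this revert restores ties each operator to the Python module holding its abstract (meta) impl via a per-operator C++ call, m.impl_abstract_pystub("op_name", "pymodule", context), instead of the reverted module-wide m.impl_abstract_pystub("pymodule"). On the Python side the abstract impl itself is registered with torch.library.impl_abstract. A minimal sketch using a hypothetical operator def'ed from Python so it is self-contained:

    import torch

    # Hypothetical namespace and op; the operators touched in this diff are
    # def'ed from C++ via TORCH_LIBRARY_FRAGMENT instead.
    lib = torch.library.Library("mylib", "DEF")
    lib.define("mysin(Tensor x) -> Tensor")

    @torch.library.impl_abstract("mylib::mysin", lib=lib)
    def mysin_abstract(x):
        # Abstract impls compute only output metadata (shape/dtype/device).
        return torch.empty_like(x)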
@@ -265,15 +265,6 @@ AbstractImplPyStubsType& abstractImplPyStubsSingleton() {
 }
 
-c10::optional<std::pair<const char*, const char*>> Dispatcher::getAbstractImplPyStub(OperatorName op_name) {
-  std::lock_guard<std::mutex> lock(guard_->mutex);
-  auto found = abstractImplPyStubsSingleton().find(op_name);
-  if (found == abstractImplPyStubsSingleton().end()) {
-    return c10::nullopt;
-  }
-  return found->second;
-}
-
 RegistrationHandleRAII Dispatcher::registerAbstractImplPyStub(
     const OperatorName& op_name,
     const char* pymodule,
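Note: the accessor deleted above backed the torch._C._dispatch_pystub binding, which this diff also removes further down. A hedged sketch of how it was queried while it existed, returning an optional (pymodule, context) pair:

    import torch

    # Pre-revert builds only; the binding is gone after this commit.
    maybe_pystub = torch._C._dispatch_pystub("custom::nonzero", "")
    if maybe_pystub is not None:
        pymodule, context = maybe_pystub  # e.g. ("my_custom_ops", "")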
@@ -314,9 +305,9 @@ void Dispatcher::throwIfHasAbstractImplPyStub(OperatorName op_name) {
       interpreter != nullptr,
       op_name,
       ": while attempting to run this operator with Meta Tensors: "
-      "Either there is no meta kernel for this operator, or it is located "
-      "in the python module ", pymodule, " which is not available "
-      "because Python isn't available.")
+      "the abstract impl for this operator (necessary for Meta Tensors) "
+      "was declared to exist in the Python module ", pymodule,
+      " but Python is not available.");
   (*interpreter)->throw_abstract_impl_not_imported_error(toString(op_name), pymodule, context);
 }
@@ -234,8 +234,6 @@ public:
    */
   void throwIfHasAbstractImplPyStub(OperatorName op_name);
 
-  c10::optional<std::pair<const char*, const char*>> getAbstractImplPyStub(OperatorName op_name);
-
   /**
   * Register a new operator by name.
   */
@@ -128,14 +128,6 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name
   }
   switch (rv) {
     case _RegisterOrVerify::REGISTER:
-      if (impl_abstract_pystub_.has_value()) {
-        registrars_.emplace_back(
-          c10::Dispatcher::singleton().registerAbstractImplPyStub(
-            schema.operator_name(),
-            impl_abstract_pystub_->first,
-            impl_abstract_pystub_->second)
-        );
-      }
       registrars_.emplace_back(
         c10::Dispatcher::singleton().registerDef(
           std::move(schema),
@@ -1,13 +1,14 @@
 import torch
+import torch._custom_ops as library
 from model import get_custom_op_library_path
 
 torch.ops.load_library(get_custom_op_library_path())
 
 
-@torch.library.impl_abstract("custom::nonzero")
+@library.impl_abstract("custom::nonzero")
 def nonzero_abstract(x):
     n = x.dim()
-    ctx = torch.library.get_ctx()
+    ctx = library.get_ctx()
     nnz = ctx.create_unbacked_symint()
     shape = [nnz, n]
     return x.new_empty(shape, dtype=torch.long)
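nonzero_abstract above shows the standard recipe for operators with data-dependent output shapes: rather than inventing a concrete size, the abstract impl asks the context for an unbacked symbolic integer. A self-contained sketch of the same pattern with a hypothetical Python-def'ed operator:

    import torch

    lib = torch.library.Library("sketchlib", "DEF")  # hypothetical namespace
    lib.define("row_index(Tensor x) -> Tensor")

    @torch.library.impl_abstract("sketchlib::row_index", lib=lib)
    def row_index_abstract(x):
        ctx = torch.library.get_ctx()
        nnz = ctx.create_unbacked_symint()  # size unknown until runtime
        return x.new_empty([nnz, x.dim()], dtype=torch.long)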
@@ -1,9 +1,10 @@
 import torch
+import torch._custom_ops as library
 from model import get_custom_op_library_path
 
 torch.ops.load_library(get_custom_op_library_path())
 
 
-@torch.library.impl_abstract("custom::sin")
+@library.impl_abstract("custom::sin")
 def sin_abstract(x):
     return torch.empty_like(x)
@@ -73,22 +73,14 @@ torch::Tensor custom_sin(torch::Tensor x) {
 
 
 TORCH_LIBRARY_FRAGMENT(custom, m) {
-  m.impl_abstract_pystub("my_custom_ops2");
   m.def("op", custom_op);
   m.def("op2", custom_op2);
   m.def("op_with_defaults(Tensor tensor, float scalar = 1, int repeat = 1) -> Tensor[]", custom_op);
   m.def("op_with_autograd(Tensor var1, int mul, Tensor var2, Tensor? var3=None) -> Tensor", custom_op_with_autograd);
   m.def("sin(Tensor x) -> Tensor");
-  m.def("cos(Tensor x) -> Tensor");
-}
-
-TORCH_LIBRARY_FRAGMENT(custom, m) {
-  m.impl_abstract_pystub("my_custom_ops");
+  m.impl_abstract_pystub("sin", "my_custom_ops2");
   m.def("nonzero(Tensor x) -> Tensor");
-}
-
-TORCH_LIBRARY_FRAGMENT(custom, m) {
-  m.def("tan(Tensor x) -> Tensor");
+  m.impl_abstract_pystub("nonzero", "my_custom_ops");
 }
 
 TORCH_LIBRARY_IMPL(custom, CPU, m) {
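After the revert, each pystub declaration sits next to its def, and the cos/tan operators (added only to exercise the module-wide form together with the pointwise test module deleted below) are dropped. To actually run these ops with meta tensors, the declared modules must be imported first; a hedged sketch of the consumer side, with a hypothetical library path:

    import torch

    torch.ops.load_library("build/libcustom_ops.so")  # hypothetical path

    import my_custom_ops   # importing registers the abstract impl for custom::nonzero
    import my_custom_ops2  # importing registers the abstract impl for custom::sin

    x = torch.randn(3, device="meta")
    y = torch.ops.custom.sin(x)  # resolvable now that the abstract impl is loaded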
@@ -1,17 +0,0 @@
-import torch
-from model import get_custom_op_library_path
-
-torch.ops.load_library(get_custom_op_library_path())
-
-
-# NB: The impl_abstract_pystub for cos actually
-# specifies it should live in the my_custom_ops2 module.
-@torch.library.impl_abstract("custom::cos")
-def cos_abstract(x):
-    return torch.empty_like(x)
-
-
-# NB: There is no impl_abstract_pystub for tan
-@torch.library.impl_abstract("custom::tan")
-def tan_abstract(x):
-    return torch.empty_like(x)
@@ -10,7 +10,6 @@ from torch import ops
 from model import Model, get_custom_op_library_path
 from torch.testing._internal.common_utils import TestCase, run_tests
 
-torch.ops.import_module("pointwise")
 
 class TestCustomOperators(TestCase):
     def setUp(self):
@@ -20,16 +19,6 @@ class TestCustomOperators(TestCase):
     def test_custom_library_is_loaded(self):
         self.assertIn(self.library_path, ops.loaded_libraries)
 
-    def test_op_with_no_abstract_impl_pystub(self):
-        x = torch.randn(3, device='meta')
-        with self.assertRaisesRegex(RuntimeError, "pointwise"):
-            torch.ops.custom.tan(x)
-
-    def test_op_with_incorrect_abstract_impl_pystub(self):
-        x = torch.randn(3, device='meta')
-        with self.assertRaisesRegex(RuntimeError, "pointwise"):
-            torch.ops.custom.cos(x)
-
     def test_abstract_impl_pystub_faketensor(self):
         from functorch import make_fx
         x = torch.randn(3, device='cpu')
@@ -1776,20 +1776,6 @@ def forward(self, x_1):
         y = self.ns().foo(x)
         assert torch.allclose(y, x.sin())
 
-    def test_defined_in_python(self):
-        self.assertFalse(torch.ops.aten.sin.default._defined_in_python)
-        self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python)
-
-        lib = self.lib()
-        torch.library.define("{self._test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
-        ns = self.ns()
-        self.assertTrue(ns.foo.default._defined_in_python)
-
-        torch.library.define(
-            "{self._test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib
-        )
-        self.assertTrue(ns.bar.overload._defined_in_python)
-
     def _test_impl_device(self, name, types, device):
         lib = self.lib()
         torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib)
@@ -1451,7 +1451,6 @@ def _dispatch_key_for_device(device_type: str) -> str: ...
 def _parse_dispatch_key(key: str) -> Optional[DispatchKey]: ...
 def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ...
 def _dispatch_num_backends() -> _int: ...
-def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ...
 def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
 def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ...
 def _functionalization_reapply_views_tls() -> _bool: ...
@@ -487,9 +487,6 @@ class OpOverload(OperatorBase):
         self.__qualname__ = self._name
         self.__annotations__ = {}
 
-        # If the OpOverload was constructed from a Library.def in Python.
-        self._defined_in_python = self.__qualname__ in torch.library._defs
-
         # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
         is_write = None
         for a in self._schema.arguments:
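_defined_in_python (removed above) was derived from the torch.library._defs bookkeeping that this revert also deletes. A sketch of how the flag behaved while it existed:

    import torch

    lib = torch.library.Library("defcheck", "DEF")  # hypothetical namespace
    lib.define("foo(Tensor x) -> Tensor")

    # Pre-revert only: True for ops def'ed through the Python Library API,
    # False for ops that came from C++ (e.g. aten ops).
    print(torch.ops.defcheck.foo.default._defined_in_python)  # True
    print(torch.ops.aten.sin.default._defined_in_python)      # False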
@@ -1,6 +1,5 @@
 import logging
 import os
-import sys
 import tempfile
 from typing import Any, Dict
 
@@ -49,16 +48,11 @@ def resolve_library_path(path: str) -> str:
 
 
 def throw_abstract_impl_not_imported_error(opname, module, context):
-    if module in sys.modules:
-        raise NotImplementedError(
-            f"{opname}: We could not find the abstract impl for this operator. "
-        )
-    else:
-        raise NotImplementedError(
-            f"{opname}: We could not find the abstract impl for this operator. "
-            f"The operator specified that you may need to import the '{module}' "
-            f"Python module to load the abstract impl. {context}"
-        )
+    raise NotImplementedError(
+        f"{opname}: We could not find the abstract impl for this operator. "
+        f"The operator specified that you need to import the '{module}' Python "
+        f"module to load the abstract impl. {context}"
+    )
 
 
 # Meta only, see
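The restored error is unconditional: it always names the module that was declared to hold the abstract impl, where the deleted variant special-cased modules already present in sys.modules. An illustrative call to the helper as it stands after the revert:

    from torch._utils_internal import throw_abstract_impl_not_imported_error

    try:
        throw_abstract_impl_not_imported_error("custom::sin", "my_custom_ops2", "")
    except NotImplementedError as e:
        # "custom::sin: We could not find the abstract impl for this operator.
        #  ...import the 'my_custom_ops2' Python module to load the abstract impl."
        print(e)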
@@ -773,10 +773,6 @@ void initDispatchBindings(PyObject* module) {
 
   m.def(
       "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
-  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
-    return c10::Dispatcher::singleton().getAbstractImplPyStub(
-        c10::OperatorName(name, overload));
-  });
 
   m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
     return at::functionalization::impl::replace_(a, b);
@@ -93,8 +93,6 @@ enum class _RegisterOrVerify { REGISTER, VERIFY };
 template <class CurClass>
 class class_;
 
-#define HAS_IMPL_ABSTRACT_PYSTUB
-
 /// Represents a C++ function that implements an operator. Most users won't
 /// interact directly with this class, except via error messages: the
 /// constructors this function define the set of permissible "function"-like
@@ -608,16 +606,19 @@ class TORCH_API Library final {
     return _def(std::move(s), nullptr, tags, rv);
   }
 
-  /// Declares that for all operators that are subsequently def'ed, their
-  /// abstract impls may be found in the given Python module (pymodule).
-  /// This registers some help text that is used if the abstract impl
-  /// cannot be found.
+  /// Declares that an operator (given by name) has an abstract impl in a
+  /// Python module (pymodule). If the abstract impl was not yet imported,
+  /// we will warn about it.
   ///
   /// Args:
+  /// - name: the name of the operator
   /// - pymodule: the python module
   /// - context: We may include this in the error message.
-  Library& impl_abstract_pystub(const char* pymodule, const char* context = "") {
-    impl_abstract_pystub_ = {pymodule, context};
+  Library& impl_abstract_pystub(const char* name, const char* pymodule, const char* context = "") {
+    at::OperatorName opname = _parseNameForLib(name);
+    registrars_.emplace_back(
+      c10::Dispatcher::singleton().registerAbstractImplPyStub(opname, pymodule, context)
+    );
     return *this;
   }
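The restored doc comment describes the C++ half of the contract; the named Python module is expected to register the matching abstract impl when imported. A sketch of what such a module looks like, mirroring the test files earlier in this diff:

    # Sketch of a module that m.impl_abstract_pystub("nonzero", "my_custom_ops")
    # points at; importing it performs the registration.
    import torch

    @torch.library.impl_abstract("custom::nonzero")
    def nonzero_abstract(x):
        ctx = torch.library.get_ctx()
        nnz = ctx.create_unbacked_symint()
        return x.new_empty([nnz, x.dim()], dtype=torch.long)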
@@ -839,7 +840,6 @@ class TORCH_API Library final {
   Kind kind_;
   c10::optional<std::string> ns_;
   c10::optional<c10::DispatchKey> dispatch_key_;
-  c10::optional<std::pair<const char*, const char*>> impl_abstract_pystub_;
   const char* file_;
   uint32_t line_;
 
@@ -4,7 +4,6 @@ import traceback
 import torch
 import weakref
 import functools
-import inspect
 import re
 
 __all__ = [
@@ -21,7 +20,6 @@ __all__ = [
 # This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
 # libraries calling into kernels not intended to be called.
 _impls: Set[str] = set()
-_defs: Set[str] = set()
 
 # prim is reserved by TorchScript interpreter
 _reserved_namespaces = ['prim']
@@ -61,7 +59,6 @@ class Library:
         filename, lineno = frame.filename, frame.lineno
         self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
         self.ns = ns
-        self._op_defs: Set[str] = set()
         self._op_impls: Set[str] = set()
         self._registration_handles: List["torch._library.utils.RegistrationHandle"] = []
         self.kind = kind
@@ -70,7 +67,7 @@
         # Python __del__ can lead to weird things (globals and locals may already
         # be gone when __del__ actually gets called!). finalizers help the
         # situation because it lets us capture references and keeps them alive
-        weakref.finalize(self, _del_library, _impls, self._op_impls, _defs, self._op_defs, self._registration_handles)
+        weakref.finalize(self, _del_library, _impls, self._op_impls, self._registration_handles)
 
     def __repr__(self):
         return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
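The finalizer above captures the module-level and per-instance sets rather than self, which is the point of the pattern: cleanup runs reliably without touching globals during interpreter teardown. A self-contained sketch of the same idiom:

    import weakref

    _global_impls = set()

    class MiniLibrary:
        def __init__(self):
            self._op_impls = set()
            # Capture only the sets; no strong reference back to self.
            weakref.finalize(
                self, MiniLibrary._cleanup, _global_impls, self._op_impls
            )

        @staticmethod
        def _cleanup(global_impls, op_impls):
            global_impls -= op_impls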
@@ -102,11 +99,7 @@
         assert self.m is not None
         if isinstance(tags, torch.Tag):
             tags = (tags,)
-        result = self.m.define(schema, alias_analysis, tuple(tags))
-        qualname = self.ns + "::" + schema.split("(")[0]
-        self._op_defs.add(qualname)
-        _defs.add(qualname)
-        return result
+        return self.m.define(schema, alias_analysis, tuple(tags))
 
     def impl(self, op_name, fn, dispatch_key=''):
         r'''Registers the function implementation for an operator defined in the library.
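The deleted bookkeeping computed a qualified operator name from the raw schema string before recording it. The string manipulation it relied on, in isolation:

    # How the removed lines derived the qualname (hypothetical values):
    ns = "mylib"
    schema = "foo(Tensor x) -> Tensor"
    qualname = ns + "::" + schema.split("(")[0]
    assert qualname == "mylib::foo"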
@@ -176,9 +169,8 @@
         self._registration_handles.clear()
 
 
-def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
+def _del_library(captured_impls, op_impls, registration_handles):
     captured_impls -= op_impls
-    captured_defs -= op_defs
     for handle in registration_handles:
         handle.destroy()
 
@@ -417,21 +409,12 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
         >>>     return torch.tensor(res, device=x.device)
 
     """
     source = torch._library.utils.get_source(_stacklevel + 1)
-    frame = inspect.stack()[_stacklevel]
-    caller_module = inspect.getmodule(frame[0])
-    # Can be none if you call impl_abstract from somewhere there isn't a module
-    # (e.g. __main__)
-    caller_module_name = None if caller_module is None else caller_module.__name__
 
     def inner(func):
         entry = torch._library.simple_registry.singleton.find(qualname)
-        if caller_module_name is not None:
-            func_to_register = _check_pystubs_once(func, qualname, caller_module_name)
-        else:
-            func_to_register = func
 
-        handle = entry.abstract_impl.register(func_to_register, source)
+        handle = entry.abstract_impl.register(func, source)
         if lib is not None:
             lib._registration_handles.append(handle)
         return func
@@ -441,45 +424,6 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
     return inner(func)
 
 
-# If the op was defined in C++, then we want to make sure there was an
-# m.impl_abstract_pystub(module, ...) call and that the module is the
-# same as the module that called torch.library.impl_abstract.
-def _check_pystubs_once(func, qualname, actual_module_name):
-    checked = False
-
-    def inner(*args, **kwargs):
-        nonlocal checked
-        if checked:
-            return func(*args, **kwargs)
-
-        op = torch._library.utils.lookup_op(qualname)
-        if op._defined_in_python:
-            checked = True
-            return func(*args, **kwargs)
-
-        maybe_pystub = torch._C._dispatch_pystub(
-            op._schema.name,
-            op._schema.overload_name)
-        if not maybe_pystub:
-            raise RuntimeError(
-                f"Operator '{qualname}' was defined in C++ and has a Python "
-                f"abstract impl. In this situation, it is required to have a "
-                f"C++ `m.impl_abstract_pystub` call, but we could not find one."
-                f"Please add a call to `m.impl_abstract_pystub(\"{actual_module_name}\");` "
-                f"to the C++ TORCH_LIBRARY block the operator was "
-                f"defined in.")
-        pystub_module = maybe_pystub[0]
-        if actual_module_name != pystub_module:
-            raise RuntimeError(
-                f"Operator '{qualname}' specified that its python abstract impl "
-                f"is in the Python module '{pystub_module}' but it was actually found "
-                f"in '{actual_module_name}'. Please either move the abstract impl "
-                f"or correct the m.impl_abstract_pystub call.")
-        checked = True
-        return func(*args, **kwargs)
-    return inner
-
-
 # NOTE [ctx inside the fake implementation]
 # If a user has an operator with data-dependent output shape, then when writing
 # a fake implementation they must query the current ctx and use methods on the
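_check_pystubs_once (deleted above) wrapped the user's abstract impl so the pystub lookup and validation ran only on the first call. Reduced to its skeleton, the memoized-check pattern looks like this:

    def check_once(func, check):
        # Sketch: validate on the first invocation, then dispatch straight
        # through to func on every later call.
        checked = False

        def inner(*args, **kwargs):
            nonlocal checked
            if not checked:
                check()  # e.g. the pystub consistency validation
                checked = True
            return func(*args, **kwargs)

        return inner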