Update impl_abstract_pystub to be less boilerplatey (#112851)

Summary:
We've made the following changes:
- The new way to use the API is `m.impl_abstract_pystub(module, context)`.
  Every subsequent `m.def` of an op inside the TORCH_LIBRARY block associates
  that op with the `impl_abstract_pystub` (see the example sketch below).
- Added a mechanism to determine if an operator was defined in Python or C++.
  Library.define in Python appends the op to a global set, which is analogous
  to what we do for tracking Library.impl.
- If someone does `torch.library.impl_abstract` in Python for an operator that
  was defined in C++, then we require that the operator have an
  `impl_abstract_pystub` specified, and we also check that the module named in
  the `impl_abstract_pystub` matches the module where the call to
  `torch.library.impl_abstract` lives.
- Unfortunately we can't check the "context" (which is the buck target on
  buck-based systems) because buck sits above us.
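
A minimal before/after sketch, based on the custom-operator test files changed in
this diff (the `custom` namespace, the `my_custom_ops2` module, and the `sin`/`cos`
ops all come from those test files):

    // Before: the pystub was registered per operator, naming each op explicitly.
    TORCH_LIBRARY_FRAGMENT(custom, m) {
      m.def("sin(Tensor x) -> Tensor");
      m.impl_abstract_pystub("sin", "my_custom_ops2");
    }

    // After: declare the pystub once; every subsequent m.def in the block picks it up.
    TORCH_LIBRARY_FRAGMENT(custom, m) {
      m.impl_abstract_pystub("my_custom_ops2");
      m.def("sin(Tensor x) -> Tensor");
      m.def("cos(Tensor x) -> Tensor");
    }

The abstract impls themselves still live in Python, e.g. a function decorated with
`@torch.library.impl_abstract("custom::sin")` inside the `my_custom_ops2` module;
registering it from a different module now results in a RuntimeError the first time
the abstract impl is used.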

Test Plan:
- existing tests

Differential Revision: D50972148

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112851
Approved by: https://github.com/ezyang
Richard Zou
2023-11-07 16:07:42 +00:00
committed by PyTorch MergeBot
parent 9a28a7b498
commit 6ae4e3a8d2
15 changed files with 166 additions and 29 deletions

View File

@@ -265,6 +265,15 @@ AbstractImplPyStubsType& abstractImplPyStubsSingleton() {
}
c10::optional<std::pair<const char*, const char*>> Dispatcher::getAbstractImplPyStub(OperatorName op_name) {
std::lock_guard<std::mutex> lock(guard_->mutex);
auto found = abstractImplPyStubsSingleton().find(op_name);
if (found == abstractImplPyStubsSingleton().end()) {
return c10::nullopt;
}
return found->second;
}
RegistrationHandleRAII Dispatcher::registerAbstractImplPyStub(
const OperatorName& op_name,
const char* pymodule,
@@ -305,9 +314,9 @@ void Dispatcher::throwIfHasAbstractImplPyStub(OperatorName op_name) {
interpreter != nullptr,
op_name,
": while attempting to run this operator with Meta Tensors: "
"the abstract impl for this operator (necessary for Meta Tensors) "
"was declared to exist in the Python module ", pymodule,
" but Python is not available.");
"Either there is no meta kernel for this operator, or it is located "
"in the python module ", pymodule, " which is not available "
"because Python isn't available.")
(*interpreter)->throw_abstract_impl_not_imported_error(toString(op_name), pymodule, context);
}

View File

@@ -234,6 +234,8 @@ public:
*/
void throwIfHasAbstractImplPyStub(OperatorName op_name);
c10::optional<std::pair<const char*, const char*>> getAbstractImplPyStub(OperatorName op_name);
/**
* Register a new operator by name.
*/

View File

@@ -128,6 +128,14 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name
}
switch (rv) {
case _RegisterOrVerify::REGISTER:
if (impl_abstract_pystub_.has_value()) {
registrars_.emplace_back(
c10::Dispatcher::singleton().registerAbstractImplPyStub(
schema.operator_name(),
impl_abstract_pystub_->first,
impl_abstract_pystub_->second)
);
}
registrars_.emplace_back(
c10::Dispatcher::singleton().registerDef(
std::move(schema),

View File

@@ -1,14 +1,13 @@
import torch
import torch._custom_ops as library
from model import get_custom_op_library_path
torch.ops.load_library(get_custom_op_library_path())
@library.impl_abstract("custom::nonzero")
@torch.library.impl_abstract("custom::nonzero")
def nonzero_abstract(x):
n = x.dim()
ctx = library.get_ctx()
ctx = torch.library.get_ctx()
nnz = ctx.create_unbacked_symint()
shape = [nnz, n]
return x.new_empty(shape, dtype=torch.long)

View File

@@ -1,10 +1,9 @@
import torch
import torch._custom_ops as library
from model import get_custom_op_library_path
torch.ops.load_library(get_custom_op_library_path())
@library.impl_abstract("custom::sin")
@torch.library.impl_abstract("custom::sin")
def sin_abstract(x):
return torch.empty_like(x)

View File

@@ -73,14 +73,22 @@ torch::Tensor custom_sin(torch::Tensor x) {
TORCH_LIBRARY_FRAGMENT(custom, m) {
m.impl_abstract_pystub("my_custom_ops2");
m.def("op", custom_op);
m.def("op2", custom_op2);
m.def("op_with_defaults(Tensor tensor, float scalar = 1, int repeat = 1) -> Tensor[]", custom_op);
m.def("op_with_autograd(Tensor var1, int mul, Tensor var2, Tensor? var3=None) -> Tensor", custom_op_with_autograd);
m.def("sin(Tensor x) -> Tensor");
m.impl_abstract_pystub("sin", "my_custom_ops2");
m.def("cos(Tensor x) -> Tensor");
}
TORCH_LIBRARY_FRAGMENT(custom, m) {
m.impl_abstract_pystub("my_custom_ops");
m.def("nonzero(Tensor x) -> Tensor");
m.impl_abstract_pystub("nonzero", "my_custom_ops");
}
TORCH_LIBRARY_FRAGMENT(custom, m) {
m.def("tan(Tensor x) -> Tensor");
}
TORCH_LIBRARY_IMPL(custom, CPU, m) {

View File

@@ -0,0 +1,17 @@
import torch
from model import get_custom_op_library_path
torch.ops.load_library(get_custom_op_library_path())
# NB: The impl_abstract_pystub for cos actually
# specifies it should live in the my_custom_ops2 module.
@torch.library.impl_abstract("custom::cos")
def cos_abstract(x):
return torch.empty_like(x)
# NB: There is no impl_abstract_pystub for tan
@torch.library.impl_abstract("custom::tan")
def tan_abstract(x):
return torch.empty_like(x)

View File

@@ -10,6 +10,7 @@ from torch import ops
from model import Model, get_custom_op_library_path
from torch.testing._internal.common_utils import TestCase, run_tests
torch.ops.import_module("pointwise")
class TestCustomOperators(TestCase):
def setUp(self):
@@ -19,6 +20,16 @@ class TestCustomOperators(TestCase):
def test_custom_library_is_loaded(self):
self.assertIn(self.library_path, ops.loaded_libraries)
def test_op_with_no_abstract_impl_pystub(self):
x = torch.randn(3, device='meta')
with self.assertRaisesRegex(RuntimeError, "pointwise"):
torch.ops.custom.tan(x)
def test_op_with_incorrect_abstract_impl_pystub(self):
x = torch.randn(3, device='meta')
with self.assertRaisesRegex(RuntimeError, "pointwise"):
torch.ops.custom.cos(x)
def test_abstract_impl_pystub_faketensor(self):
from functorch import make_fx
x = torch.randn(3, device='cpu')

View File

@@ -1776,6 +1776,20 @@ def forward(self, x_1):
y = self.ns().foo(x)
assert torch.allclose(y, x.sin())
def test_defined_in_python(self):
self.assertFalse(torch.ops.aten.sin.default._defined_in_python)
self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python)
lib = self.lib()
torch.library.define("{self._test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
ns = self.ns()
self.assertTrue(ns.foo.default._defined_in_python)
torch.library.define(
"{self._test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib
)
self.assertTrue(ns.bar.overload._defined_in_python)
def _test_impl_device(self, name, types, device):
lib = self.lib()
torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib)

View File

@@ -1451,6 +1451,7 @@ def _dispatch_key_for_device(device_type: str) -> str: ...
def _parse_dispatch_key(key: str) -> Optional[DispatchKey]: ...
def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ...
def _dispatch_num_backends() -> _int: ...
def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ...
def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ...
def _functionalization_reapply_views_tls() -> _bool: ...

View File

@@ -487,6 +487,9 @@ class OpOverload(OperatorBase):
self.__qualname__ = self._name
self.__annotations__ = {}
# If the OpOverload was constructed from a Library.def in Python.
self._defined_in_python = self.__qualname__ in torch.library._defs
# Logic replicated from aten/src/ATen/native/MathBitsFallback.h
is_write = None
for a in self._schema.arguments:

View File

@@ -1,5 +1,6 @@
import logging
import os
import sys
import tempfile
from typing import Any, Dict
@@ -48,11 +49,16 @@ def resolve_library_path(path: str) -> str:
def throw_abstract_impl_not_imported_error(opname, module, context):
raise NotImplementedError(
f"{opname}: We could not find the abstract impl for this operator. "
f"The operator specified that you need to import the '{module}' Python "
f"module to load the abstract impl. {context}"
)
if module in sys.modules:
raise NotImplementedError(
f"{opname}: We could not find the abstract impl for this operator. "
)
else:
raise NotImplementedError(
f"{opname}: We could not find the abstract impl for this operator. "
f"The operator specified that you may need to import the '{module}' "
f"Python module to load the abstract impl. {context}"
)
# Meta only, see

View File

@@ -773,6 +773,10 @@ void initDispatchBindings(PyObject* module) {
m.def(
"_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
m.def("_dispatch_pystub", [](const char* name, const char* overload) {
return c10::Dispatcher::singleton().getAbstractImplPyStub(
c10::OperatorName(name, overload));
});
m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
return at::functionalization::impl::replace_(a, b);

View File

@@ -93,6 +93,8 @@ enum class _RegisterOrVerify { REGISTER, VERIFY };
template <class CurClass>
class class_;
#define HAS_IMPL_ABSTRACT_PYSTUB
/// Represents a C++ function that implements an operator. Most users won't
/// interact directly with this class, except via error messages: the
/// constructors this function define the set of permissible "function"-like
@@ -606,19 +608,16 @@ class TORCH_API Library final {
return _def(std::move(s), nullptr, tags, rv);
}
/// Declares that an operator (given by name) has an abstract impl in a
/// Python module (pymodule). If the abstract impl was not yet imported,
/// we will warn about it.
/// Declares that for all operators that are subsequently def'ed, their
/// abstract impls may be found in the given Python module (pymodule).
/// This registers some help text that is used if the abstract impl
/// cannot be found.
///
/// Args:
/// - name: the name of the operator
/// - pymodule: the python module
/// - context: We may include this in the error message.
Library& impl_abstract_pystub(const char* name, const char* pymodule, const char* context = "") {
at::OperatorName opname = _parseNameForLib(name);
registrars_.emplace_back(
c10::Dispatcher::singleton().registerAbstractImplPyStub(opname, pymodule, context)
);
Library& impl_abstract_pystub(const char* pymodule, const char* context = "") {
impl_abstract_pystub_ = {pymodule, context};
return *this;
}
@@ -840,6 +839,7 @@ class TORCH_API Library final {
Kind kind_;
c10::optional<std::string> ns_;
c10::optional<c10::DispatchKey> dispatch_key_;
c10::optional<std::pair<const char*, const char*>> impl_abstract_pystub_;
const char* file_;
uint32_t line_;

View File

@@ -4,6 +4,7 @@ import traceback
import torch
import weakref
import functools
import inspect
import re
__all__ = [
@@ -20,6 +21,7 @@ __all__ = [
# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
# libraries calling into kernels not intended to be called.
_impls: Set[str] = set()
_defs: Set[str] = set()
# prim is reserved by TorchScript interpreter
_reserved_namespaces = ['prim']
@@ -59,6 +61,7 @@ class Library:
filename, lineno = frame.filename, frame.lineno
self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
self.ns = ns
self._op_defs: Set[str] = set()
self._op_impls: Set[str] = set()
self._registration_handles: List["torch._library.utils.RegistrationHandle"] = []
self.kind = kind
@@ -67,7 +70,7 @@ class Library:
# Python __del__ can lead to weird things (globals and locals may already
# be gone when __del__ actually gets called!). finalizers help the
# situation because it lets us capture references and keeps them alive
weakref.finalize(self, _del_library, _impls, self._op_impls, self._registration_handles)
weakref.finalize(self, _del_library, _impls, self._op_impls, _defs, self._op_defs, self._registration_handles)
def __repr__(self):
return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
@@ -99,7 +102,11 @@ class Library:
assert self.m is not None
if isinstance(tags, torch.Tag):
tags = (tags,)
return self.m.define(schema, alias_analysis, tuple(tags))
result = self.m.define(schema, alias_analysis, tuple(tags))
qualname = self.ns + "::" + schema.split("(")[0]
self._op_defs.add(qualname)
_defs.add(qualname)
return result
def impl(self, op_name, fn, dispatch_key=''):
r'''Registers the function implementation for an operator defined in the library.
@@ -169,8 +176,9 @@ class Library:
self._registration_handles.clear()
def _del_library(captured_impls, op_impls, registration_handles):
def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
captured_impls -= op_impls
captured_defs -= op_defs
for handle in registration_handles:
handle.destroy()
@@ -409,12 +417,21 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
>>> return torch.tensor(res, device=x.device)
"""
source = torch._library.utils.get_source(_stacklevel + 1)
frame = inspect.stack()[_stacklevel]
caller_module = inspect.getmodule(frame[0])
# Can be none if you call impl_abstract from somewhere there isn't a module
# (e.g. __main__)
caller_module_name = None if caller_module is None else caller_module.__name__
def inner(func):
entry = torch._library.simple_registry.singleton.find(qualname)
handle = entry.abstract_impl.register(func, source)
if caller_module_name is not None:
func_to_register = _check_pystubs_once(func, qualname, caller_module_name)
else:
func_to_register = func
handle = entry.abstract_impl.register(func_to_register, source)
if lib is not None:
lib._registration_handles.append(handle)
return func
@@ -424,6 +441,45 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
return inner(func)
# If the op was defined in C++, then we want to make sure there was an
# m.impl_abstract_pystub(module, ...) call and that the module is the
# same as the module that called torch.library.impl_abstract.
def _check_pystubs_once(func, qualname, actual_module_name):
checked = False
def inner(*args, **kwargs):
nonlocal checked
if checked:
return func(*args, **kwargs)
op = torch._library.utils.lookup_op(qualname)
if op._defined_in_python:
checked = True
return func(*args, **kwargs)
maybe_pystub = torch._C._dispatch_pystub(
op._schema.name,
op._schema.overload_name)
if not maybe_pystub:
raise RuntimeError(
f"Operator '{qualname}' was defined in C++ and has a Python "
f"abstract impl. In this situation, it is required to have a "
f"C++ `m.impl_abstract_pystub` call, but we could not find one."
f"Please add a call to `m.impl_abstract_pystub(\"{actual_module_name}\");` "
f"to the C++ TORCH_LIBRARY block the operator was "
f"defined in.")
pystub_module = maybe_pystub[0]
if actual_module_name != pystub_module:
raise RuntimeError(
f"Operator '{qualname}' specified that its python abstract impl "
f"is in the Python module '{pystub_module}' but it was actually found "
f"in '{actual_module_name}'. Please either move the abstract impl "
f"or correct the m.impl_abstract_pystub call.")
checked = True
return func(*args, **kwargs)
return inner
# NOTE [ctx inside the fake implementation]
# If a user has an operator with data-dependent output shape, then when writing
# a fake implementation they must query the current ctx and use methods on the