Better error messages for impl_abstract_pystub (#120959)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120959
Approved by: https://github.com/drisspg
This commit is contained in:
rzou
2024-03-01 11:17:16 -08:00
committed by PyTorch MergeBot
parent ce2903080c
commit 3ef0befdc9
6 changed files with 24 additions and 10 deletions

View File

@ -42,7 +42,7 @@ class TestCustomOperators(TestCase):
def f(x):
return torch.ops.custom.asin(x)
with self.assertRaisesRegex(RuntimeError, r'unsupported operator: .* \(you may need to `import nonexistent`'):
with self.assertRaisesRegex(RuntimeError, r'unsupported operator: .* you may need to `import nonexistent`'):
f(x)
def test_abstract_impl_pystub_faketensor(self):
@ -64,7 +64,7 @@ def forward(self, arg0_1):
def test_abstract_impl_pystub_meta(self):
x = torch.randn(3, device="meta")
self.assertNotIn("my_custom_ops2", sys.modules.keys())
with self.assertRaisesRegex(NotImplementedError, r"import the 'my_custom_ops2'"):
with self.assertRaisesRegex(NotImplementedError, r"'my_custom_ops2'"):
y = torch.ops.custom.sin.default(x)
torch.ops.import_module("my_custom_ops2")
y = torch.ops.custom.sin.default(x)

View File

@ -1392,6 +1392,7 @@ class Generator:
class _DispatchOperatorHandle:
def schema(self) -> FunctionSchema: ...
def debug(self) -> str: ...
class _DispatchModule:
def def_(self, schema: str, alias: str = "") -> _DispatchModule: ...

View File

@ -1688,7 +1688,11 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
)
if maybe_pystub is not None:
module, ctx = maybe_pystub
import_suggestion = f"you may need to `import {module}` ({ctx}) for support, otherwise "
import_suggestion = (
f"It's possible that the support was implemented in "
f"module `{module}` and you may need to `import {module}`"
f"({ctx}), otherwise "
)
unimplemented(
f"unsupported operator: {cause.func} ({import_suggestion}see "
"https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"

View File

@ -614,6 +614,11 @@ class OpOverload(OperatorBase):
def namespace(self):
return self._schema.name.split("::")[0]
def _handle(self):
return torch._C._dispatch_find_schema_or_throw(
self._schema.name, self._schema.overload_name
)
def decompose(self, *args, **kwargs):
dk = torch._C.DispatchKey.CompositeImplicitAutograd
if dk in self.py_kernels:

View File

@ -241,7 +241,8 @@ void initDispatchBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
.def("schema", &c10::OperatorHandle::schema);
.def("schema", &c10::OperatorHandle::schema)
.def("debug", &c10::OperatorHandle::debug);
m.def("_dispatch_call_boxed", &ophandle_call_boxed);

View File

@ -494,20 +494,23 @@ def _check_pystubs_once(func, qualname, actual_module_name):
op._schema.name,
op._schema.overload_name)
if not maybe_pystub:
namespace = op.namespace
cpp_filename = op._handle().debug()
raise RuntimeError(
f"Operator '{qualname}' was defined in C++ and has a Python "
f"abstract impl. In this situation, it is required to have a "
f"C++ `m.impl_abstract_pystub` call, but we could not find one."
f"Please add a call to `m.impl_abstract_pystub(\"{actual_module_name}\");` "
f"to the C++ TORCH_LIBRARY block the operator was "
f"defined in.")
f"abstract impl. In this situation, we require there to also be a "
f"companion C++ `m.impl_abstract_pystub(\"{actual_module_name}\")` "
f"call, but we could not find one. Please add that to "
f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
f"operator was registered in ({cpp_filename})")
pystub_module = maybe_pystub[0]
if actual_module_name != pystub_module:
cpp_filename = op._handle().debug()
raise RuntimeError(
f"Operator '{qualname}' specified that its python abstract impl "
f"is in the Python module '{pystub_module}' but it was actually found "
f"in '{actual_module_name}'. Please either move the abstract impl "
f"or correct the m.impl_abstract_pystub call.")
f"or correct the m.impl_abstract_pystub call ({cpp_filename})")
checked = True
return func(*args, **kwargs)
return inner