fix external codegen kernel error checking (#85029)

Fixes https://github.com/pytorch/pytorch/issues/84987. I followed the repro steps from the issue (changed `empty_symint` to `empty_symint2`) and confirmed that an error gets raised.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85029
Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh
2022-09-16 13:04:09 -07:00
committed by PyTorch MergeBot
parent 652707abc0
commit 1838957e6f
2 changed files with 68 additions and 12 deletions

View File

@ -270,30 +270,49 @@ def error_on_missing_kernels(
if full_codegen is None:
full_codegen = []
expected_backend_op_names: List[OperatorName] = (
list(backend_indices[backend_key].index.keys()) + []
if autograd_key is None
else list(backend_indices[autograd_key].index.keys())
indices = [backend_indices[backend_key].index] + (
[] if autograd_key is None else [backend_indices[autograd_key].index]
)
# Quick mapping from each OperatorName used by the external backend
# to its backend kernel name
expected_backend_op_names: Dict[OperatorName, str] = dict(
list(
concatMap(
lambda index: [
(op_name, metadata.kernel) for op_name, metadata in index.items()
],
indices,
)
)
)
expected_backend_native_funcs: List[NativeFunction] = [
f
for f in native_functions
if f.func.name in expected_backend_op_names and f.func.name not in full_codegen
if f.func.name in expected_backend_op_names.keys()
and f.func.name not in full_codegen
]
expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict(
list
)
for native_f in expected_backend_native_funcs:
expected_backend_kernel_name_counts[dispatcher.name(native_f.func)].append(
native_f
)
expected_backend_kernel_name_counts[
expected_backend_op_names[native_f.func.name]
].append(native_f)
# This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
# It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
# here, then we get a nicer error message. If we miss it, you get a linker error.
kernel_defn_regex = rf"{class_name}::\s*([\w\d]*)\("
kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\("
actual_backend_kernel_name_counts = Counter(
re.findall(kernel_defn_regex, backend_defns)
# A bit unwieldy (this could probably be moved into regex),
# but we don't want to include kernel names that come from function calls,
# like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)".
# Easy check is to ignore any lines with colons before the class name.
[
y
for (x, y) in re.findall(kernel_defn_regex, backend_defns)
if not x.endswith(":")
]
)
missing_kernels_err_msg = ""