mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix external codegen kernel error checking (#85029)
Fixes https://github.com/pytorch/pytorch/issues/84987. I followed the repro steps from the issue (changed `empty_symint` to `empty_symint2` and confirmed that and error gets raised. Pull Request resolved: https://github.com/pytorch/pytorch/pull/85029 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
652707abc0
commit
1838957e6f
@ -270,30 +270,49 @@ def error_on_missing_kernels(
|
||||
if full_codegen is None:
|
||||
full_codegen = []
|
||||
|
||||
expected_backend_op_names: List[OperatorName] = (
|
||||
list(backend_indices[backend_key].index.keys()) + []
|
||||
if autograd_key is None
|
||||
else list(backend_indices[autograd_key].index.keys())
|
||||
indices = [backend_indices[backend_key].index] + (
|
||||
[] if autograd_key is None else [backend_indices[autograd_key].index]
|
||||
)
|
||||
# Quick mapping from each OperatorName used by the external backend
|
||||
# to its backend kernel name
|
||||
expected_backend_op_names: Dict[OperatorName, str] = dict(
|
||||
list(
|
||||
concatMap(
|
||||
lambda index: [
|
||||
(op_name, metadata.kernel) for op_name, metadata in index.items()
|
||||
],
|
||||
indices,
|
||||
)
|
||||
)
|
||||
)
|
||||
expected_backend_native_funcs: List[NativeFunction] = [
|
||||
f
|
||||
for f in native_functions
|
||||
if f.func.name in expected_backend_op_names and f.func.name not in full_codegen
|
||||
if f.func.name in expected_backend_op_names.keys()
|
||||
and f.func.name not in full_codegen
|
||||
]
|
||||
expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict(
|
||||
list
|
||||
)
|
||||
for native_f in expected_backend_native_funcs:
|
||||
expected_backend_kernel_name_counts[dispatcher.name(native_f.func)].append(
|
||||
native_f
|
||||
)
|
||||
expected_backend_kernel_name_counts[
|
||||
expected_backend_op_names[native_f.func.name]
|
||||
].append(native_f)
|
||||
|
||||
# This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
|
||||
# It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
|
||||
# here, then we get a nicer error message. If we miss it, you get a linker error.
|
||||
kernel_defn_regex = rf"{class_name}::\s*([\w\d]*)\("
|
||||
kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\("
|
||||
actual_backend_kernel_name_counts = Counter(
|
||||
re.findall(kernel_defn_regex, backend_defns)
|
||||
# A bit unwieldy (this could probably be moved into regex),
|
||||
# but we don't want to include kernel names that come from function calls,
|
||||
# like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)".
|
||||
# Easy check is to ignore any lines with colons before the class name.
|
||||
[
|
||||
y
|
||||
for (x, y) in re.findall(kernel_defn_regex, backend_defns)
|
||||
if not x.endswith(":")
|
||||
]
|
||||
)
|
||||
|
||||
missing_kernels_err_msg = ""
|
||||
|
Reference in New Issue
Block a user