Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[torchgen] Fix multiple backends with custom namespace (#82133)
Summary: Some quantized operators need the `QuantizedCPU` backend. Due to an issue in namespace checking, codegen currently hits an assertion error if a native function has two backends as well as a custom namespace. This PR fixes that issue.

The root cause is that codegen asserts that a native function may only have one kernel namespace. Currently, if a native function is not found in a `BackendIndex`, we fall back to the default namespace for that backend (for fallback kernels). However, that default namespace may not be listed in the YAML file, and it should not be counted when checking whether the backends use two different namespaces. In our error case there are two `BackendIndex` instances, one for `QuantizedCPU` and one for `CPU`. The native function has no kernel for `QuantizedCPU`, so the default namespace (`at::native`) is used for it; because the `CPU` dispatch key uses a custom namespace, we ran into the assertion error. This PR changes the assertion criteria: we only error out if a native function has two or more kernels registered under two or more different namespaces.

Test Plan: rely on the newly added unit test

Differential Revision: D38101345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82133
Approved by: https://github.com/iseeyuan
Committed by: PyTorch MergeBot
Parent: 3ca78a4c75
Commit: 301fe8c27d
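To make the failure mode described in the summary concrete, here is a minimal sketch that rebuilds it with the same torchgen APIs the new unit test below exercises. The operator name `my_op` and the kernel name `custom::my_kernel` are made up for illustration and do not appear in this PR; the point is that the `QuantizedCPU` index has no kernel for the operator, so before this fix its default `at::native` namespace was counted next to the custom `CPU` namespace and tripped the assertion.

# Minimal sketch (not part of this PR) of the failing configuration, built with
# the same APIs the new test exercises. "my_op" and "custom::my_kernel" are
# hypothetical names.
from typing import Dict

import torchgen.model
from torchgen.gen import get_native_function_declarations
from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DispatchKey,
    NativeFunction,
    OperatorName,
)

# A CPU kernel living in a custom namespace; no QuantizedCPU kernel at all.
op, op_backend_index = NativeFunction.from_yaml(
    {"func": "my_op() -> bool", "dispatch": {"CPU": "custom::my_kernel"}},
    loc=torchgen.model.Location(__file__, 1),
    valid_tags=set(),
)

backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
    DispatchKey.CPU: {},
    DispatchKey.QuantizedCPU: {},  # present, but has no kernel for my_op
}
BackendIndex.grow_index(backend_indices, op_backend_index)
indices = {
    k: BackendIndex(
        dispatch_key=k,
        use_out_as_primary=True,
        external=False,
        device_guard=False,
        index=backend_indices[k],
    )
    for k in backend_indices
}

# Before this PR, the missing QuantizedCPU kernel contributed the default
# at::native namespace and this call hit the one-namespace assertion; with
# the fix it should emit the declaration in the custom namespace instead.
print("\n".join(get_native_function_declarations(
    grouped_native_functions=[op], backend_indices=indices
)))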
@@ -261,3 +261,16 @@ def define_tools_targets(
             ":gen_operators_yaml_lib",
         ],
     )
+
+    python_test(
+        name = "test_codegen",
+        srcs = [
+            "test/test_codegen.py",
+        ],
+        contacts = contacts,
+        visibility = ["PUBLIC"],
+        deps = [
+            torchgen_deps,
+            ":autograd",
+        ],
+    )
@@ -1,11 +1,22 @@
 import dataclasses
 import typing
 import unittest
+from typing import Dict
 
 import torchgen.model
 
 from tools.autograd import gen_autograd_functions, load_derivatives
-from torchgen.gen import get_native_function_schema_registrations
+from torchgen.gen import (
+    get_native_function_declarations,
+    get_native_function_schema_registrations,
+)
+from torchgen.model import (
+    BackendIndex,
+    BackendMetadata,
+    DispatchKey,
+    NativeFunction,
+    OperatorName,
+)
 from torchgen.selective_build.selector import SelectiveBuilder
 
@@ -198,6 +209,67 @@ TORCH_LIBRARY(custom, m) {
         )
 
 
+class TestGenNativeFunctionDeclaration(unittest.TestCase):
+    def setUp(self) -> None:
+        self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
+            {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
+            loc=torchgen.model.Location(__file__, 1),
+            valid_tags=set(),
+        )
+        self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
+            {
+                "func": "op_2() -> bool",
+                "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
+            },
+            loc=torchgen.model.Location(__file__, 1),
+            valid_tags=set(),
+        )
+
+        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
+            DispatchKey.CPU: {},
+            DispatchKey.QuantizedCPU: {},
+        }
+        BackendIndex.grow_index(backend_indices, op_1_backend_index)
+        BackendIndex.grow_index(backend_indices, op_2_backend_index)
+        self.backend_indices = {
+            k: BackendIndex(
+                dispatch_key=k,
+                use_out_as_primary=True,
+                external=False,
+                device_guard=False,
+                index=backend_indices[k],
+            )
+            for k in backend_indices
+        }
+
+    def test_native_function_declaration_1_op_2_ns_error(self) -> None:
+        with self.assertRaises(AssertionError):
+            get_native_function_declarations(
+                grouped_native_functions=[
+                    self.op_1_native_function,
+                    self.op_2_native_function,
+                ],
+                backend_indices=self.backend_indices,
+            )
+
+    def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
+        self.assertIsInstance(self.op_1_native_function, NativeFunction)
+        declaration = get_native_function_declarations(
+            grouped_native_functions=[
+                self.op_1_native_function,
+            ],
+            backend_indices=self.backend_indices,
+        )
+        target = """
+namespace at {
+namespace native {
+TORCH_API bool kernel_1();
+} // namespace native
+} // namespace at
+        """
+        self.assertEqual("\n".join(declaration), target)
+
+
 # Represents the most basic NativeFunction. Use dataclasses.replace()
 # to edit for use.
 DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(
@@ -1366,17 +1366,18 @@ def get_native_function_declarations(
     newline = "\n"
     for f in grouped_native_functions:
         native_function_namespaces = set()
-        for backend_idx in backend_indices.values():
+        dispatch_keys = set()
+        for dispatch_key, backend_idx in backend_indices.items():
             backend_metadata = backend_idx.get_kernel(f)
-            namespace = (
-                backend_metadata.cpp_namespace
-                if backend_metadata
-                else DEFAULT_KERNEL_NAMESPACE
-            )
-            native_function_namespaces.add(namespace)
+            if backend_metadata:
+                namespace = backend_metadata.cpp_namespace
+                dispatch_keys.add(dispatch_key)
+                native_function_namespaces.add(namespace)
+            else:
+                namespace = DEFAULT_KERNEL_NAMESPACE
             assert (
-                len(native_function_namespaces) == 1
-            ), "Codegen only supports one namespace per operator."
+                len(native_function_namespaces) <= 1
+            ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
             ns_grouped_kernels[namespace].extend(
                 dest.compute_native_function_declaration(f, backend_idx)
             )
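The behavioral change above is easiest to see in isolation. The following is a simplified, self-contained model of the check rather than the actual code in torchgen.gen.get_native_function_declarations: the old rule counted a backend's default namespace even when that backend had no kernel for the operator, while the new rule only counts namespaces of kernels that actually exist (and tolerates an empty set).

# Simplified model of the namespace check, for illustration only; the real
# logic lives in get_native_function_declarations above. Namespace strings
# here are illustrative.
from typing import Dict, Optional

DEFAULT_KERNEL_NAMESPACE = "at::native"

def old_rule(kernel_namespaces: Dict[str, Optional[str]]) -> None:
    # kernel_namespaces maps a dispatch key to the kernel's namespace,
    # or None when that backend has no kernel for the operator.
    seen = set()
    for ns in kernel_namespaces.values():
        # Missing kernels fall back to the default namespace and are counted.
        seen.add(ns if ns is not None else DEFAULT_KERNEL_NAMESPACE)
        assert len(seen) == 1, "Codegen only supports one namespace per operator."

def new_rule(kernel_namespaces: Dict[str, Optional[str]]) -> None:
    seen = set()
    for ns in kernel_namespaces.values():
        # Only backends that actually provide a kernel contribute a namespace.
        if ns is not None:
            seen.add(ns)
        assert len(seen) <= 1, "Codegen only supports one namespace per operator."

# The case from the summary: a CPU kernel in a custom namespace and no
# QuantizedCPU kernel at all.
case = {"CPU": "custom::native", "QuantizedCPU": None}

new_rule(case)  # passes: only one real namespace is registered
try:
    old_rule(case)
except AssertionError:
    print("old rule rejects a valid operator")  # the bug this PR fixes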
@@ -1115,7 +1115,7 @@ class BackendIndex:
         elif isinstance(g, NativeFunctionsGroup):
             f = self.primary(g)
         else:
-            assert_never(f)
+            assert_never(g)
         if f.func.name not in self.index:
             return None
         return self.index[f.func.name]