[torchgen] Fix multiple backends with custom namespace (#82133)

Summary:
Some quantized operators need the `QuantizedCPU` backend. Due to an issue in namespace checking, if we have two backends as well as a custom namespace in a native function, codegen will currently hit an assertion error. This PR fixes that issue.

The root cause is that codegen right now asserts that a native function should only have one namespace. The current behavior is that if a native function is not found in a `BackendIndex`, we use the default namespace for that backend, for fallback kernels. However, that default namespace may not be listed in the yaml file, and it should not be counted when checking whether we have two different namespaces for that backend. In our error case, we have 2 `BackendIndex` objects, one for `QuantizedCPU` and one for `CPU`. The native function doesn't have a kernel in `QuantizedCPU`, but we still use a default namespace (`at::native`) for it. Since we have a custom namespace for dispatch key `CPU`, we ran into the assertion error.

This PR changes the assertion criteria: we only error out if an operator has kernels registered under two or more different namespaces. Backends with no kernel for the operator (which merely fall back to the default namespace) are no longer counted.

Test Plan: rely on the newly added unit test.

Differential Revision: D38101345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82133
Approved by: https://github.com/iseeyuan
This commit is contained in:
Mengwei Liu
2022-07-29 22:53:58 +00:00
committed by PyTorch MergeBot
parent 3ca78a4c75
commit 301fe8c27d
4 changed files with 97 additions and 11 deletions

View File

@ -261,3 +261,16 @@ def define_tools_targets(
":gen_operators_yaml_lib",
],
)
    # Unit tests for the torchgen/autograd code generators; runs
    # tools/test/test_codegen.py against torchgen and the autograd tooling.
    python_test(
        name = "test_codegen",
        srcs = [
            "test/test_codegen.py",
        ],
        contacts = contacts,
        visibility = ["PUBLIC"],
        deps = [
            torchgen_deps,
            ":autograd",
        ],
    )

View File

@ -1,11 +1,22 @@
import dataclasses
import typing
import unittest
from typing import Dict
import torchgen.model
from tools.autograd import gen_autograd_functions, load_derivatives
from torchgen.gen import get_native_function_schema_registrations
from torchgen.gen import (
get_native_function_declarations,
get_native_function_schema_registrations,
)
from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
NativeFunction,
OperatorName,
)
from torchgen.selective_build.selector import SelectiveBuilder
@ -198,6 +209,67 @@ TORCH_LIBRARY(custom, m) {
)
class TestGenNativeFunctionDeclaration(unittest.TestCase):
    """Tests for get_native_function_declarations' per-operator namespace
    checking: an operator whose kernels span more than one C++ namespace
    must be rejected, while a single-namespace operator must generate a
    declaration in that namespace."""

    def setUp(self) -> None:
        # op_1: a single CPU kernel with no namespace qualifier, so its
        # declaration lands in the default kernel namespace (at::native).
        self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
            {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        # op_2: kernels under two different namespaces — the default one
        # for CPU and an explicit `custom::` one for QuantizedCPU. This is
        # the shape that must trigger the codegen assertion.
        self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
            {
                "func": "op_2() -> bool",
                "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
            },
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        # Collect both operators' kernel metadata into per-dispatch-key
        # indices, then wrap each in a BackendIndex as codegen expects.
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
            DispatchKey.CPU: {},
            DispatchKey.QuantizedCPU: {},
        }
        BackendIndex.grow_index(backend_indices, op_1_backend_index)
        BackendIndex.grow_index(backend_indices, op_2_backend_index)
        self.backend_indices = {
            k: BackendIndex(
                dispatch_key=k,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=backend_indices[k],
            )
            for k in backend_indices
        }

    def test_native_function_declaration_1_op_2_ns_error(self) -> None:
        """op_2's kernels live in two namespaces, so generating
        declarations for the group must raise an AssertionError."""
        with self.assertRaises(AssertionError):
            get_native_function_declarations(
                grouped_native_functions=[
                    self.op_1_native_function,
                    self.op_2_native_function,
                ],
                backend_indices=self.backend_indices,
            )

    def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
        """op_1 alone uses a single namespace; its declaration must be
        emitted inside at::native."""
        self.assertIsInstance(self.op_1_native_function, NativeFunction)
        declaration = get_native_function_declarations(
            grouped_native_functions=[
                self.op_1_native_function,
            ],
            backend_indices=self.backend_indices,
        )
        target = """
namespace at {
namespace native {
TORCH_API bool kernel_1();
} // namespace native
} // namespace at
        """
        self.assertEqual("\n".join(declaration), target)
# Represents the most basic NativeFunction. Use dataclasses.replace()
# to edit for use.
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(

View File

@ -1366,17 +1366,18 @@ def get_native_function_declarations(
newline = "\n"
for f in grouped_native_functions:
native_function_namespaces = set()
for backend_idx in backend_indices.values():
dispatch_keys = set()
for dispatch_key, backend_idx in backend_indices.items():
backend_metadata = backend_idx.get_kernel(f)
namespace = (
backend_metadata.cpp_namespace
if backend_metadata
else DEFAULT_KERNEL_NAMESPACE
)
native_function_namespaces.add(namespace)
if backend_metadata:
namespace = backend_metadata.cpp_namespace
dispatch_keys.add(dispatch_key)
native_function_namespaces.add(namespace)
else:
namespace = DEFAULT_KERNEL_NAMESPACE
assert (
len(native_function_namespaces) == 1
), "Codegen only supports one namespace per operator."
len(native_function_namespaces) <= 1
), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
ns_grouped_kernels[namespace].extend(
dest.compute_native_function_declaration(f, backend_idx)
)

View File

@ -1115,7 +1115,7 @@ class BackendIndex:
elif isinstance(g, NativeFunctionsGroup):
f = self.primary(g)
else:
assert_never(f)
assert_never(g)
if f.func.name not in self.index:
return None
return self.index[f.func.name]