mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
There are four core ATen ops with Composite Implicit Autograd (CIA) dispatch: upsample_bilinear2d.vec, upsample_nearest2d.vec, avg_pool1d, and adaptive_avg_pool1d. Op variant auto-generation is currently skipped for CIA ops. In preparation for disabling the decompositions for upsample ops by default in export, we need to generate out variants for these ops. This change enables autogen for core-tagged CIA ops, which enables generation of upsample_bilinear2d.vec_out and upsample_nearest2d.vec_out. Test Plan: Added a new test, test_functional_variant_autogen_out_variant_core, in test_codegen.py to cover this case. Confirmed that the upsample_bilinear2d.vec_out and upsample_nearest2d.vec_out op overloads are registered (they were previously not available). Differential Revision: D66590257 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141797 Approved by: https://github.com/larryliu0820
536 lines
20 KiB
Python
536 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import typing
|
|
import unittest
|
|
from collections import defaultdict
|
|
|
|
import yaml
|
|
from tools.autograd import gen_autograd_functions, load_derivatives
|
|
|
|
from torchgen import dest
|
|
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
|
|
from torchgen.context import native_function_manager
|
|
from torchgen.gen import (
|
|
get_native_function_declarations,
|
|
get_native_function_schema_registrations,
|
|
LineLoader,
|
|
static_dispatch,
|
|
)
|
|
from torchgen.model import (
|
|
BackendIndex,
|
|
BackendMetadata,
|
|
DispatchKey,
|
|
FunctionSchema,
|
|
Location,
|
|
NativeFunction,
|
|
OperatorName,
|
|
)
|
|
from torchgen.native_function_generation import add_generated_native_functions
|
|
from torchgen.selective_build.selector import SelectiveBuilder
|
|
|
|
|
|
class TestCreateDerivative(unittest.TestCase):
    """Tests for ``load_derivatives.create_derivative`` and
    ``load_derivatives.create_differentiability_info``."""

    def test_named_grads(self) -> None:
        # Gradients referenced by output name (grad_x, grad_y) in the formula
        # are reported back in ``named_gradients``.
        schema = FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        derivative = load_derivatives.create_derivative(
            native_function,
            formula="func_backward(grad_x, grad_y)",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        self.assertSetEqual(derivative.named_gradients, {"grad_x", "grad_y"})

    def test_non_differentiable_output(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {"Default": {"a": "grads[0]", "b": "grads[2]"}},
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )

        self.assertSequenceEqual(
            differentiability_info["Default"].available_named_gradients,
            # grad_y is not present because y is a
            # bool and thus not differentiable.
            ["grad_x", "grad_z"],
        )

    def test_indexed_grads(self) -> None:
        # A formula that only uses positional gradients (grads[i]) reports no
        # named gradients.
        schema = FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        derivative = load_derivatives.create_derivative(
            native_function,
            formula="func_backward(grads[0], grads[1])",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        self.assertSetEqual(derivative.named_gradients, set())

    def test_named_grads_and_indexed_grads(self) -> None:
        # Mixing grad_<name> and grads[i] in one derivative entry must raise.
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        with self.assertRaisesRegex(
            RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"'
        ):
            load_derivatives.create_differentiability_info(
                defn_dict={
                    "name": specification,
                    # Uh-oh, the derivatives reference gradients by
                    # name and by index.
                    "dispatch": {
                        "Default": {
                            "a": "grad_x",
                            "b": "grads[1]",
                        }
                    },
                },
                functions_by_signature={schema.signature(): [native_function]},
                functions_by_schema={specification: native_function},
                op_counter=typing.Counter[str](),
                used_dispatch_keys=set(),
            )
|
|
|
|
|
|
class TestGenAutogradFunctions(unittest.TestCase):
    """Tests for ``gen_autograd_functions.process_function`` output and for
    dispatch-key validation in ``load_derivatives``."""

    def test_non_differentiable_output_invalid_type(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    }
                },
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        # unittest assertions (unlike bare `assert`) still run under
        # `python -O` and show the generated text on failure.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)

    def test_non_differentiable_output_output_differentiability(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    },
                    "AutogradNestedTensor": {
                        "a": "grad_z",
                        "b": "grad_x",
                    },
                },
                # y is a Tensor here, but is explicitly marked
                # non-differentiable.
                "output_differentiability": [True, False, True],
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )
        default_definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        self.assertNotIn("grad_z = grads[2]", default_definition)
        self.assertIn("grad_z = grads[1]", default_definition)

        # The AutogradNestedTensor entry gets the same index remapping.
        nested_tensor_definition = gen_autograd_functions.process_function(
            differentiability_info["AutogradNestedTensor"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        self.assertNotIn("grad_z = grads[2]", nested_tensor_definition)
        self.assertIn("grad_z = grads[1]", nested_tensor_definition)

    def test_register_bogus_dispatch_key(self) -> None:
        # An unknown key under "dispatch" is rejected with a RuntimeError.
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        with self.assertRaisesRegex(
            RuntimeError,
            "Invalid dispatch key AutogradRandomTensor in derivatives.yaml for",
        ):
            load_derivatives.create_differentiability_info(
                defn_dict={
                    "name": specification,
                    "dispatch": {
                        "Default": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                        "AutogradRandomTensor": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                    },
                },
                functions_by_signature={schema.signature(): [native_function]},
                functions_by_schema={specification: native_function},
                op_counter=typing.Counter[str](),
                used_dispatch_keys=set(),
            )
|
|
|
|
|
|
class TestGenSchemaRegistration(unittest.TestCase):
    """Tests for ``get_native_function_schema_registrations``.

    The function returns a pair. The first element is the list of
    ``m.def(...)`` lines for the default (aten) namespace. The second is a
    single C++ source string with one ``TORCH_LIBRARY`` or
    ``TORCH_LIBRARY_FRAGMENT`` block per custom namespace.
    """

    def setUp(self) -> None:
        self.selector = SelectiveBuilder.get_nop_selector()
        self.custom_native_function, _ = NativeFunction.from_yaml(
            {"func": "custom::func() -> bool"},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        # quantized_decomposed is expected to be emitted as a fragment
        # (see test_fragment_custom_namespace_schema_registration_code_valid).
        (
            self.fragment_custom_native_function,
            _,
        ) = NativeFunction.from_yaml(
            {"func": "quantized_decomposed::func() -> bool"},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )

    def test_default_namespace_schema_registration_code_valid(self) -> None:
        native_functions = [DEFAULT_NATIVE_FUNCTION]
        registrations, _ = get_native_function_schema_registrations(
            native_functions=native_functions,
            schema_selector=self.selector,
        )
        self.assertEqual(registrations, ['m.def("func() -> bool", {});\n'])

    def test_custom_namespace_schema_registration_code_valid(self) -> None:
        _, registrations = get_native_function_schema_registrations(
            native_functions=[self.custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(
            registrations,
            """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};""",
        )

    def test_fragment_custom_namespace_schema_registration_code_valid(self) -> None:
        """Sometimes we want to extend an existing namespace, for example the
        quantized namespace, which is already defined in
        native/quantized/library.cpp.
        """
        _, registrations = get_native_function_schema_registrations(
            native_functions=[self.fragment_custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(
            registrations,
            """
TORCH_LIBRARY_FRAGMENT(quantized_decomposed, m) {
  m.def("func() -> bool", {});

};""",
        )

    def test_mixed_namespace_schema_registration_code_valid(self) -> None:
        (
            aten_registrations,
            custom_registrations,
        ) = get_native_function_schema_registrations(
            native_functions=[DEFAULT_NATIVE_FUNCTION, self.custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n'])
        self.assertEqual(
            custom_registrations,
            """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};""",
        )

    def test_3_namespaces_schema_registration_code_valid(self) -> None:
        custom2_native_function, _ = NativeFunction.from_yaml(
            {"func": "custom2::func() -> bool"},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        (
            aten_registrations,
            custom_registrations,
        ) = get_native_function_schema_registrations(
            native_functions=[
                DEFAULT_NATIVE_FUNCTION,
                self.custom_native_function,
                custom2_native_function,
            ],
            schema_selector=self.selector,
        )
        self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n'])
        # Each custom namespace gets its own TORCH_LIBRARY block.
        self.assertEqual(
            custom_registrations,
            """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};
TORCH_LIBRARY(custom2, m) {
  m.def("func() -> bool", {});

};""",
        )
|
|
|
|
|
|
class TestGenNativeFunctionDeclaration(unittest.TestCase):
    """Tests for ``get_native_function_declarations``."""

    def setUp(self) -> None:
        self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
            {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        # op_2 places its QuantizedCPU kernel in the ``custom`` namespace
        # while its CPU kernel stays in the default namespace.
        self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
            {
                "func": "op_2() -> bool",
                "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
            },
            loc=Location(__file__, 1),
            valid_tags=set(),
        )

        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
            DispatchKey.CPU: {},
            DispatchKey.QuantizedCPU: {},
        }
        BackendIndex.grow_index(backend_indices, op_1_backend_index)
        BackendIndex.grow_index(backend_indices, op_2_backend_index)
        # Wrap each raw per-key index in a BackendIndex, as the API expects.
        self.backend_indices = {
            k: BackendIndex(
                dispatch_key=k,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=backend_indices[k],
            )
            for k in backend_indices
        }

    def test_native_function_declaration_1_op_2_ns_error(self) -> None:
        # op_2's kernels live in two namespaces; declaration generation is
        # expected to reject this with an AssertionError.
        with self.assertRaises(AssertionError):
            get_native_function_declarations(
                grouped_native_functions=[
                    self.op_1_native_function,
                    self.op_2_native_function,
                ],
                backend_indices=self.backend_indices,
                native_function_decl_gen=dest.compute_native_function_declaration,
            )

    def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
        self.assertIsInstance(self.op_1_native_function, NativeFunction)
        declaration = get_native_function_declarations(
            grouped_native_functions=[
                self.op_1_native_function,
            ],
            backend_indices=self.backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        # The trailing indented line is part of the expected output.
        target = """
namespace at {
namespace native {
TORCH_API bool kernel_1();
} // namespace native
} // namespace at
        """
        self.assertEqual("\n".join(declaration), target)
|
|
|
|
|
|
# Test for native_function_generation
class TestNativeFunctionGeneratrion(unittest.TestCase):
    """Tests for ``add_generated_native_functions`` autogen of out= variants."""

    def setUp(self) -> None:
        self.native_functions: list[NativeFunction] = []
        self.backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
            defaultdict(dict)
        )
        # YAML nesting is significant: "dispatch" and "autogen" belong to
        # the same list item as "func".
        yaml_entry = """
- func: op(Tensor self) -> Tensor
  dispatch:
    CompositeExplicitAutograd: op
  autogen: op.out
        """
        es = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, one_return_backend_index = NativeFunction.from_yaml(
            es[0], loc=Location(__file__, 1), valid_tags=set()
        )

        BackendIndex.grow_index(self.backend_indices, one_return_backend_index)

        self.two_returns_func, two_returns_backend_index = NativeFunction.from_yaml(
            {
                "func": "op_2() -> (Tensor, Tensor)",
                "dispatch": {"CPU": "kernel_1"},
                "autogen": "op_2.out",
            },
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)

        # No "dispatch" entry and tagged "core". Per the core test's docstring,
        # this models a core CompositeImplicitAutograd op.
        self.core_func, core_func_index = NativeFunction.from_yaml(
            {
                "func": "op_3.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor",
                "autogen": "op_3.vec_out",
                "tags": ["core"],
            },
            loc=Location(__file__, 1),
            valid_tags={"core"},
        )
        BackendIndex.grow_index(self.backend_indices, core_func_index)

    def test_functional_variant_autogen_out_variant(self) -> None:
        # add_generated_native_functions appends the generated out= variant
        # to the list in place.
        native_functions = [self.one_return_func]
        add_generated_native_functions(native_functions, self.backend_indices)
        self.assertEqual(len(native_functions), 2)
        self.assertEqual(
            str(native_functions[1].func),
            "op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)",
        )
        op_name = native_functions[1].func.name
        backend_metadata = self.backend_indices[DispatchKey.CompositeExplicitAutograd][
            op_name
        ]
        self.assertEqual(backend_metadata.kernel, "op_out")

    def test_functional_variant_autogen_out_variant_two_returns(self) -> None:
        native_functions = [self.two_returns_func]
        add_generated_native_functions(native_functions, self.backend_indices)
        self.assertEqual(len(native_functions), 2)
        self.assertEqual(
            str(native_functions[1].func),
            "op_2.out(*, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
        )
        op_name = native_functions[1].func.name
        backend_metadata = self.backend_indices[DispatchKey.CompositeExplicitAutograd][
            op_name
        ]
        self.assertEqual(backend_metadata.kernel, "op_2_out")

    def test_functional_variant_autogen_out_variant_core(self) -> None:
        """
        Tests autogen of out variants for core-tagged ops that are CompositeImplicitAutograd.
        """
        native_functions = [self.core_func]
        add_generated_native_functions(native_functions, self.backend_indices)
        self.assertEqual(len(native_functions), 2)
        self.assertEqual(
            str(native_functions[1].func),
            "op_3.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)",
        )
|
|
|
|
|
|
# Test for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase):
    """Tests for ``static_dispatch`` with a single CompositeExplicitAutograd
    backend registered for an out= op."""

    def setUp(self) -> None:
        self.backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
            defaultdict(dict)
        )
        # YAML nesting is significant: "dispatch" belongs to the same list
        # item as "func".
        yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: op
        """
        es = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, one_return_backend_index = NativeFunction.from_yaml(
            es[0], loc=Location(__file__, 1), valid_tags=set()
        )

        BackendIndex.grow_index(self.backend_indices, one_return_backend_index)
        dispatch_key = DispatchKey.CompositeExplicitAutograd
        self.assertIn(dispatch_key, self.backend_indices)
        self.indices = [
            BackendIndex(
                dispatch_key=dispatch_key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=self.backend_indices[dispatch_key],
            )
        ]

    def test_op_with_1_backend_generates_static_dispatch(self) -> None:
        disp_sig = DispatcherSignature.from_schema(self.one_return_func.func)
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=disp_sig,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(
            out, "return at::compositeexplicitautograd::op_out(out, self);"
        )

    def test_op_with_cpp_sig_generates_static_dispatch(self) -> None:
        sig_group = CppSignatureGroup.from_native_function(
            self.one_return_func,
            method=False,
            fallback_binding=self.one_return_func.manual_cpp_binding,
        )
        # cpp signature puts out at the front
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=sig_group.signature,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(
            out, "return at::compositeexplicitautograd::op_out(out, self);"
        )
|
|
|
|
|
|
# Represents the most basic NativeFunction. Use dataclasses.replace()
# to edit for use.
# It is defined after the test classes but is only referenced inside test
# method bodies, so it already exists by the time any test runs.
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
    {"func": "func() -> bool"},
    loc=Location(__file__, 1),
    valid_tags=set(),
)
|
|
|
|
|
|
# Allow running this file directly as well as through a test runner.
if __name__ == "__main__":
    unittest.main()
|