Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[HOP] support generating schema for hop (#133521)
Add a way of generating a FunctionSchema from example values, because a HOP's schema can vary even across instances of the same HOP. We don't use torch._C.FunctionSchema because its classes cannot be constructed directly from Python (e.g. "__init__" cannot be used for torch._C.FunctionSchema), and extending the basic types in C++ is not straightforward. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133521 Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: dd5a7c8397
Commit: 6835f20d20
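The diffs below add a test helper (get_hop_schema) and the library support behind it. As a rough end-to-end sketch of the intended usage (the module M, its branch functions, and the exact printed schema are illustrative, not part of the commit):

import torch
from torch.export import export


class M(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])


ep = export(M(), (torch.randn(4),))
# Find the HigherOrderOperator node (torch.ops.higher_order.cond) in the exported graph
# and generate a schema from the example values recorded on that node.
hop_node = next(
    n for n in ep.graph.nodes if isinstance(n.target, torch._ops.HigherOrderOperator)
)
print(torch._library.utils.hop_schema_from_fx_node(hop_node))
# e.g. cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[1] operands) -> Tensor[1]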
@@ -185,6 +185,15 @@ def is_training_ir_test(test_name):
    )


def get_hop_schema(ep: torch.export.ExportedProgram):
    hop_node = next(
        node
        for node in ep.graph.nodes
        if isinstance(node.target, torch._ops.HigherOrderOperator)
    )
    return torch._library.utils.hop_schema_from_fx_node(hop_node)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestDynamismExpression(TestCase):
    def test_export_inline_constraints(self):
@@ -4181,6 +4190,11 @@ def forward(self, x):
        dim0 = torch.export.Dim("dim0", min=3)
        inp = torch.ones(6, 4)
        ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
        schema = get_hop_schema(ep)
        self.assertExpectedInline(
            str(schema),
            """cond(SymBool pred, GraphModule true_fn, GraphModule false_fn, Tensor[2] operands) -> Tensor[1]""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
@@ -5570,6 +5584,11 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
    add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
    return (add,)""",
        )
        schema = get_hop_schema(ep)
        self.assertExpectedInline(
            str(schema),
            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",
        )

        cond_top_level_nn_module_stack = [
            node.meta["nn_module_stack"]
@@ -5655,6 +5674,12 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
            strict=False,
        )

        schema = get_hop_schema(exported_program)
        self.assertExpectedInline(
            str(schema),
            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",  # noqa: B950
        )

        self.assertExpectedInline(
            str(exported_program.graph_module.code.strip()),
            """\
@@ -6139,6 +6164,12 @@ def forward(self, x):
            (torch.randn(4), torch.randn(4), torch.randn(4)),
            pre_dispatch=True,
        )

        schema = get_hop_schema(ep)
        self.assertExpectedInline(
            str(schema),
            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",
        )
        # test cond subgraph
        expected_names_and_ops = [
            ("mul_2", "placeholder"),
@@ -3711,6 +3711,143 @@ def forward(self, l_inp_, l_tmp_):
        self.assertEqual(cnt.frame_count, 3)


_hop_schema_test_schema_types = [
    "bool",
    "int",
    "float",
    "str",
    "Tensor",
    "SymInt",
    "SymBool",
    "GraphModule",
    "ScriptObj",
]


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
class TestHopSchema(TestCase):
    def _get_example_val(self, ty: str):
        from torch.fx.experimental.sym_node import SymNode
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        def create_symtype(cls, pytype, shape_env, val):
            from torch._dynamo.source import ConstantSource

            symbol = shape_env.create_symbol(
                val,
                source=ConstantSource(
                    f"__testing_hop_schema{len(shape_env.var_to_val)}"
                ),
            )
            return cls(SymNode(symbol, shape_env, pytype, hint=val))

        if ty == "bool":
            return True
        elif ty == "int":
            return 1
        elif ty == "float":
            return 1.0
        elif ty == "str":
            return "foo"
        elif ty == "Tensor":
            return torch.tensor(1)
        elif ty == "SymInt":
            shape_env = ShapeEnv()
            return create_symtype(torch.SymInt, int, shape_env, 1)
        elif ty == "SymBool":
            shape_env = ShapeEnv()
            return create_symtype(torch.SymBool, bool, shape_env, True)
        elif ty == "GraphModule":

            def f(x):
                return x.sin()

            return make_fx(f)(torch.ones(1))
        elif ty == "ScriptObj":
            from torch.testing._internal.torchbind_impls import (
                init_torchbind_implementations,
            )

            init_torchbind_implementations()
            foo = torch.classes._TorchScriptTesting._Foo(3, 4)
            return foo
        else:
            raise NotImplementedError(ty)
    @parametrize("schema_type", _hop_schema_test_schema_types)
    def test_type_gen(self, schema_type):
        from torchgen.gen_schema_utils import TypeGen

        example_val = self._get_example_val(schema_type)
        ty = TypeGen.from_example(example_val)
        # Test that the generated type can be parsed back.
        self.assertEqual(ty.parse(str(ty)), ty)

    @parametrize("schema_type", _hop_schema_test_schema_types)
    def test_list_gen(self, schema_type):
        from torchgen.gen_schema_utils import TypeGen

        example_val = self._get_example_val(schema_type)
        li1 = [example_val]
        li2 = [example_val, example_val]
        ty1 = TypeGen.from_example(li1)
        ty2 = TypeGen.from_example(li2)
        self.assertEqual(ty1.parse(str(ty1)), ty1)
        self.assertEqual(ty2.parse(str(ty2)), ty2)
    def test_function_schema_gen(self):
        from torchgen.gen_schema_utils import FunctionSchemaGen

        inps = [
            (schema_type + "_v", self._get_example_val(schema_type))
            for schema_type in _hop_schema_test_schema_types
        ]
        op_name = "test_op"
        schema1 = FunctionSchemaGen.from_example("test_op1", inps, torch.ones(1))
        schema2 = FunctionSchemaGen.from_example(
            "test_op2",
            inps,
            [
                torch.ones(1),
            ],
        )
        schema3 = FunctionSchemaGen.from_example(
            "test_op3", inps, [torch.ones(1), torch.ones(1)]
        )
        self.assertExpectedInline(
            str(schema1),
            """test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(schema2),
            """test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(schema3),
            """test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""",  # noqa: B950
        )
        self.assertEqual(schema1.parse(str(schema1)), schema1)
        self.assertEqual(schema2.parse(str(schema2)), schema2)
        self.assertEqual(schema3.parse(str(schema3)), schema3)
    def test_while_loop_schema_gen(self):
        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
        graph = make_fx(fn)(*inp).graph
        while_loop_node = next(
            node
            for node in graph.nodes
            if node.op == "call_function"
            and node.target is torch.ops.higher_order.while_loop
        )
        schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node)
        self.assertExpectedInline(
            str(schema),
            """while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""",  # noqa: B950
        )
        self.assertEqual(schema.parse(str(schema)), schema)


instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)

instantiate_parametrized_tests(TestControlFlow)
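Since the generated schemas are torchgen FunctionSchema objects (rather than torch._C.FunctionSchema, per the commit message), their string form can be parsed back with torchgen.model.FunctionSchema.parse, which is what the round-trip assertions above exercise. A small hedged sketch using one of the expected strings from these tests:

from torchgen.model import FunctionSchema

s = FunctionSchema.parse(
    "cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]"
)
print(s.name)                              # cond
print(s.arguments.flat_positional[1].type)  # GraphModule
print(s.returns[0].type)                   # Tensor[1]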
@@ -37,6 +37,17 @@ def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
    return combined_leaves


class AssociativeScanOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("associative_scan")

    def __call__(self, combine_fn, input, dim):
        return super().__call__(combine_fn, input, dim)


associative_scan_op = AssociativeScanOp()


def associative_scan(
    combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
    input: pytree.PyTree,
@@ -110,9 +121,6 @@ def associative_scan(
    return pytree.tree_unflatten(result_flat, spec)


associative_scan_op = HigherOrderOperator("associative_scan")


def trace_associative_scan(
    proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
):
@@ -40,6 +40,22 @@ from .utils import _from_fun, create_fw_bw_graph

log = logging.getLogger(__name__)

"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""


class CondOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("cond")

    def __call__(self, pred, true_fn, false_fn, operands):
        return super().__call__(pred, true_fn, false_fn, operands)


cond_op = CondOp()


@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
@@ -160,13 +176,6 @@ def cond(pred, true_fn, false_fn, operands):
    )


"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
cond_op = HigherOrderOperator("cond")


def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
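One reason the module-level HigherOrderOperator instances are replaced with the CondOp / AssociativeScanOp subclasses above: schema generation (hop_schema_from_fx_node, next hunk) binds the node's example inputs against inspect.signature(hop.__call__), so the HOP needs named parameters rather than *args. A minimal sketch of that binding (the branch GraphModules here are stand-ins built with make_fx):

import inspect

import torch
from torch.fx.experimental.proxy_tensor import make_fx

true_gm = make_fx(lambda x: x.sin())(torch.ones(3))
false_gm = make_fx(lambda x: x.cos())(torch.ones(3))

# cond_op is also reachable as torch.ops.higher_order.cond; with the explicit
# __call__ defined above, its signature exposes named parameters.
sig = inspect.signature(torch.ops.higher_order.cond.__call__)
bound = sig.bind(torch.tensor(True), true_gm, false_gm, [torch.ones(3)])
print(list(bound.arguments.keys()))  # ['pred', 'true_fn', 'false_fn', 'operands']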
@@ -201,6 +201,44 @@ def zip_schema(
    return


def hop_schema_from_fx_node(node):
    from torchgen.gen_schema_utils import FunctionSchemaGen

    hop = node.target
    if not isinstance(hop, torch._ops.HigherOrderOperator):
        raise RuntimeError("fx_node's target must be a hop.")

    def _collect_example_val(node):
        meta_val = node.meta.get("val", None)
        if meta_val is None:
            assert node.op == "get_attr"
            meta_val = getattr(node.graph.owning_module, node.target)
        return meta_val

    example_inputs = []
    for arg in node.args:
        if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
            example_inputs.append(_collect_example_val(arg))
        elif isinstance(
            arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
        ):
            example_inputs.append([_collect_example_val(x) for x in arg])
        else:
            raise RuntimeError(f"Unsupported arg type {type(arg)}")

    # Bind the arguments to make sure the number of inputs is correct.
    bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
        *example_inputs
    )

    # We treat example_output as a single value in the return. This distinguishes
    # 1. returning a single value from 2. returning a tuple with one element.
    example_output = _collect_example_val(node)
    return FunctionSchemaGen.from_example(
        hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
    )


def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
    assert isinstance(op, OpOverload)
    if is_builtin(op):
torchgen/gen_schema_utils.py (new file, 97 lines)
@@ -0,0 +1,97 @@
from typing import Any, Optional, Tuple, Union

from torchgen.model import (
    Annotation,
    Argument,
    Arguments,
    BaseOperatorName,
    BaseTy,
    BaseType,
    CustomClassType,
    FunctionSchema,
    ListType,
    OperatorName,
    Return,
)


# Note: these aren't actually used in torchgen; they are utilities for generating a schema
# from real arguments. For example, this is used to generate HigherOrderOperators' schemas,
# since their schemas can vary across different instances of the same HOP.


class TypeGen:
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        elif isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        elif isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        elif isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        elif isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
        elif isinstance(obj, (list, tuple)):
            assert len(obj) > 0
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and giving each element a proper name "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], len(obj))
        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])
class ReturnGen:
    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        return Return(name, TypeGen.from_example(obj), annotation)


class ArgumentGen:
    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        return Argument(
            name, TypeGen.from_example(obj), default=default, annotation=annotation
        )
class FunctionSchemaGen:
    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: Tuple[Tuple[str, Any], ...],
        example_outputs: Tuple[Any, ...],
    ) -> FunctionSchema:
        args = []
        for name, inp in example_inputs:
            args.append(ArgumentGen.from_example(name, inp, None, None))
        # Ignore the annotations and other attributes for now; we could add more when needed.
        arguments = Arguments(
            tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
        )
        returns = tuple(
            ReturnGen.from_example(None, out, None) for out in example_outputs
        )
        op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
        return FunctionSchema(op_name, arguments, returns)
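A short usage sketch of the new utilities, mirroring test_type_gen and test_function_schema_gen above (the op name my_op and the argument names are illustrative):

import torch
from torchgen.gen_schema_utils import FunctionSchemaGen, TypeGen

print(TypeGen.from_example(torch.ones(1)))                   # Tensor
print(TypeGen.from_example([torch.ones(1), torch.ones(1)]))  # Tensor[2]

inps = [("x", torch.ones(2)), ("flag", True), ("scale", 1.5)]
schema = FunctionSchemaGen.from_example("my_op", inps, (torch.ones(2),))
print(schema)  # my_op(Tensor x, bool flag, float scale) -> Tensor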
@@ -1880,7 +1880,9 @@ class BaseTy(Enum):
    Storage = auto()
    Stream = auto()
    SymInt = auto()
    SymBool = auto()
    ConstQuantizerPtr = auto()  # TODO: rename
    GraphModule = auto()


@dataclass(frozen=True)