[HOP] support generating schema for hop (#133521)

Add a way of generating a FunctionSchema from example values, because a HOP's schema can vary even across instances of the same HOP.

We didn't use torch._C.FunctionSchema because those classes cannot be constructed directly from Python (e.g. "__init__" is not available for torch._C.FunctionSchema), and extending the basic types in C++ does not seem easy either.
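For illustration, here is a minimal sketch of the new API. The op name, argument names, and the exact printed schema below are assumptions for this example only, not part of the PR's tests:

```python
# Hypothetical example: generate a FunctionSchema from example argument/return
# values via the new torchgen.gen_schema_utils helpers added in this PR.
import torch
from torchgen.gen_schema_utils import FunctionSchemaGen

schema = FunctionSchemaGen.from_example(
    "my_hop",  # hypothetical op name, for illustration only
    (
        ("pred", True),  # bool example value -> `bool pred`
        ("operands", [torch.ones(3), torch.ones(3)]),  # homogeneous list -> `Tensor[2] operands`
    ),
    (torch.ones(3),),  # single example output -> `-> Tensor`
)
# Expected to print something like: my_hop(bool pred, Tensor[2] operands) -> Tensor
print(schema)
```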

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133521
Approved by: https://github.com/zou3519
Yidi Wu
2024-08-20 15:52:45 -07:00
committed by PyTorch MergeBot
parent dd5a7c8397
commit 6835f20d20
7 changed files with 332 additions and 10 deletions

View File

@@ -185,6 +185,15 @@ def is_training_ir_test(test_name):
    )


def get_hop_schema(ep: torch.export.ExportedProgram):
    hop_node = next(
        node
        for node in ep.graph.nodes
        if isinstance(node.target, torch._ops.HigherOrderOperator)
    )
    return torch._library.utils.hop_schema_from_fx_node(hop_node)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestDynamismExpression(TestCase):
    def test_export_inline_constraints(self):
@@ -4181,6 +4190,11 @@ def forward(self, x):
        dim0 = torch.export.Dim("dim0", min=3)
        inp = torch.ones(6, 4)
        ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
        schema = get_hop_schema(ep)
        self.assertExpectedInline(
            str(schema),
            """cond(SymBool pred, GraphModule true_fn, GraphModule false_fn, Tensor[2] operands) -> Tensor[1]""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
@@ -5570,6 +5584,11 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
    add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
    return (add,)""",
        )
        schema = get_hop_schema(ep)
        self.assertExpectedInline(
            str(schema),
            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",
        )

        cond_top_level_nn_module_stack = [
            node.meta["nn_module_stack"]
@@ -5655,6 +5674,12 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
            strict=False,
        )

        schema = get_hop_schema(exported_program)
        self.assertExpectedInline(
            str(schema),
            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(exported_program.graph_module.code.strip()),
            """\
@@ -6139,6 +6164,12 @@ def forward(self, x):
            (torch.randn(4), torch.randn(4), torch.randn(4)),
            pre_dispatch=True,
        )

        schema = get_hop_schema(ep)
        self.assertExpectedInline(
            str(schema),
            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",
        )

        # test cond subgraph
        expected_names_and_ops = [
            ("mul_2", "placeholder"),

View File

@@ -3711,6 +3711,143 @@ def forward(self, l_inp_, l_tmp_):
        self.assertEqual(cnt.frame_count, 3)

_hop_schema_test_schema_types = [
    "bool",
    "int",
    "float",
    "str",
    "Tensor",
    "SymInt",
    "SymBool",
    "GraphModule",
    "ScriptObj",
]


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
class TestHopSchema(TestCase):
    def _get_example_val(self, ty: str):
        from torch.fx.experimental.sym_node import SymNode
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        def create_symtype(cls, pytype, shape_env, val):
            from torch._dynamo.source import ConstantSource

            symbol = shape_env.create_symbol(
                val,
                source=ConstantSource(
                    f"__testing_hop_schema{len(shape_env.var_to_val)}"
                ),
            )
            return cls(SymNode(symbol, shape_env, pytype, hint=val))

        if ty == "bool":
            return True
        elif ty == "int":
            return 1
        elif ty == "float":
            return 1.0
        elif ty == "str":
            return "foo"
        elif ty == "Tensor":
            return torch.tensor(1)
        elif ty == "SymInt":
            shape_env = ShapeEnv()
            return create_symtype(torch.SymInt, int, shape_env, 1)
        elif ty == "SymBool":
            shape_env = ShapeEnv()
            return create_symtype(torch.SymBool, bool, shape_env, True)
        elif ty == "GraphModule":

            def f(x):
                return x.sin()

            return make_fx(f)(torch.ones(1))
        elif ty == "ScriptObj":
            from torch.testing._internal.torchbind_impls import (
                init_torchbind_implementations,
            )

            init_torchbind_implementations()
            foo = torch.classes._TorchScriptTesting._Foo(3, 4)
            return foo
        else:
            raise NotImplementedError(ty)
@parametrize("schema_type", _hop_schema_test_schema_types)
def test_type_gen(self, schema_type):
from torchgen.gen_schema_utils import TypeGen
example_val = self._get_example_val(schema_type)
ty = TypeGen.from_example(example_val)
# Test the generated type can be parsed
self.assertEqual(ty.parse(str(ty)), ty)
@parametrize("schema_type", _hop_schema_test_schema_types)
def test_list_gen(self, schema_type):
from torchgen.gen_schema_utils import TypeGen
example_val = self._get_example_val(schema_type)
li1 = [example_val]
li2 = [example_val, example_val]
ty1 = TypeGen.from_example(li1)
ty2 = TypeGen.from_example(li1)
self.assertEqual(ty1.parse(str(ty1)), ty1)
self.assertEqual(ty2.parse(str(ty2)), ty2)

    def test_function_schema_gen(self):
        from torchgen.gen_schema_utils import FunctionSchemaGen

        inps = [
            (schema_type + "_v", self._get_example_val(schema_type))
            for schema_type in _hop_schema_test_schema_types
        ]
        op_name = "test_op"
        schema1 = FunctionSchemaGen.from_example("test_op1", inps, torch.ones(1))
        schema2 = FunctionSchemaGen.from_example(
            "test_op2",
            inps,
            [
                torch.ones(1),
            ],
        )
        schema3 = FunctionSchemaGen.from_example(
            "test_op3", inps, [torch.ones(1), torch.ones(1)]
        )
        self.assertExpectedInline(
            str(schema1),
            """test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(schema2),
            """test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(schema3),
            """test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""",  # noqa: B950
        )
        self.assertEqual(schema1.parse(str(schema1)), schema1)
        self.assertEqual(schema2.parse(str(schema2)), schema2)
        self.assertEqual(schema3.parse(str(schema3)), schema3)

    def test_while_loop_schema_gen(self):
        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
        graph = make_fx(fn)(*inp).graph
        while_loop_node = next(
            node
            for node in graph.nodes
            if node.op == "call_function"
            and node.target is torch.ops.higher_order.while_loop
        )
        schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node)
        self.assertExpectedInline(
            str(schema),
            """while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""",  # noqa: B950
        )
        self.assertEqual(schema.parse(str(schema)), schema)


instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)
instantiate_parametrized_tests(TestControlFlow)

View File

@@ -37,6 +37,17 @@ def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
    return combined_leaves


class AssociativeScanOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("associative_scan")

    def __call__(self, combine_fn, input, dim):
        return super().__call__(combine_fn, input, dim)


associative_scan_op = AssociativeScanOp()


def associative_scan(
    combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
    input: pytree.PyTree,
@@ -110,9 +121,6 @@ def associative_scan(
    return pytree.tree_unflatten(result_flat, spec)

associative_scan_op = HigherOrderOperator("associative_scan")

def trace_associative_scan(
    proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
):

View File

@@ -40,6 +40,22 @@ from .utils import _from_fun, create_fw_bw_graph
log = logging.getLogger(__name__)

"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""


class CondOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("cond")

    def __call__(self, pred, true_fn, false_fn, operands):
        return super().__call__(pred, true_fn, false_fn, operands)


cond_op = CondOp()


@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
@@ -160,13 +176,6 @@ def cond(pred, true_fn, false_fn, operands):
    )

"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""

cond_op = HigherOrderOperator("cond")

def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py

View File

@@ -201,6 +201,44 @@ def zip_schema(
    return


def hop_schema_from_fx_node(node):
    from torchgen.gen_schema_utils import FunctionSchemaGen

    hop = node.target
    if not isinstance(hop, torch._ops.HigherOrderOperator):
        raise RuntimeError("fx_node's target must be a hop.")

    def _collect_example_val(node):
        meta_val = node.meta.get("val", None)
        if meta_val is None:
            assert node.op == "get_attr"
            meta_val = getattr(node.graph.owning_module, node.target)
        return meta_val

    example_inputs = []
    for arg in node.args:
        if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
            example_inputs.append(_collect_example_val(arg))
        elif isinstance(
            arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
        ):
            example_inputs.append([_collect_example_val(x) for x in arg])
        else:
            raise RuntimeError(f"Unsupported arg type {type(arg)}")

    # Bind the arguments to make sure the number of inputs is correct.
    bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
        *example_inputs
    )

    # We treat example_output as a single value in the return. This differentiates
    # 1. returning a single value vs 2. returning a tuple with one element.
    example_output = _collect_example_val(node)
    return FunctionSchemaGen.from_example(
        hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
    )


def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
    assert isinstance(op, OpOverload)
    if is_builtin(op):

View File

@@ -0,0 +1,97 @@
from typing import Any, Optional, Tuple, Union

from torchgen.model import (
    Annotation,
    Argument,
    Arguments,
    BaseOperatorName,
    BaseTy,
    BaseType,
    CustomClassType,
    FunctionSchema,
    ListType,
    OperatorName,
    Return,
)


# Note: These aren't actually used in torchgen; they're utilities for generating a schema
# from real arguments. For example, this is used to generate HigherOrderOperators' schemas,
# since their schemas can vary across different instances of the same HOP.


class TypeGen:
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        elif isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        elif isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        elif isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        elif isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
        elif isinstance(obj, (list, tuple)):
            assert len(obj) > 0
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and giving proper names to the elements if possible "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], len(obj))
        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])


class ReturnGen:
    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        return Return(name, TypeGen.from_example(obj), annotation)


class ArgumentGen:
    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        return Argument(
            name, TypeGen.from_example(obj), default=default, annotation=annotation
        )


class FunctionSchemaGen:
    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: Tuple[Tuple[str, Any], ...],
        example_outputs: Tuple[Any, ...],
    ) -> FunctionSchema:
        args = []
        for name, inp in example_inputs:
            args.append(ArgumentGen.from_example(name, inp, None, None))
        # ignore the annotations and other attributes for now, we could add more when needed.
        arguments = Arguments(
            tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
        )
        returns = tuple(
            ReturnGen.from_example(None, out, None) for out in example_outputs
        )
        op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
        return FunctionSchema(op_name, arguments, returns)

View File

@@ -1880,7 +1880,9 @@ class BaseTy(Enum):
    Storage = auto()
    Stream = auto()
    SymInt = auto()
    SymBool = auto()
    ConstQuantizerPtr = auto()  # TODO: rename
    GraphModule = auto()


@dataclass(frozen=True)