mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[HOP] support generating schema for hop (#133521)
Add a way of generating a FunctionSchema from example values because hop's schema varies even for the same hop. We didn't use torch._C.FunctionSchema because we cannot construct the classes directly (e.g. "__init__" cannot be used for torch._C.FunctionSchema). Also extending the Basic types in c++ seems not that easy. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133521 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
dd5a7c8397
commit
6835f20d20
@ -185,6 +185,15 @@ def is_training_ir_test(test_name):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hop_schema(ep: torch.export.ExportedProgram):
|
||||||
|
hop_node = next(
|
||||||
|
node
|
||||||
|
for node in ep.graph.nodes
|
||||||
|
if isinstance(node.target, torch._ops.HigherOrderOperator)
|
||||||
|
)
|
||||||
|
return torch._library.utils.hop_schema_from_fx_node(hop_node)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||||
class TestDynamismExpression(TestCase):
|
class TestDynamismExpression(TestCase):
|
||||||
def test_export_inline_constraints(self):
|
def test_export_inline_constraints(self):
|
||||||
@ -4181,6 +4190,11 @@ def forward(self, x):
|
|||||||
dim0 = torch.export.Dim("dim0", min=3)
|
dim0 = torch.export.Dim("dim0", min=3)
|
||||||
inp = torch.ones(6, 4)
|
inp = torch.ones(6, 4)
|
||||||
ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
|
ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
|
||||||
|
schema = get_hop_schema(ep)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(schema),
|
||||||
|
"""cond(SymBool pred, GraphModule true_fn, GraphModule false_fn, Tensor[2] operands) -> Tensor[1]""",
|
||||||
|
)
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
ep.graph_module.code.strip(),
|
ep.graph_module.code.strip(),
|
||||||
"""\
|
"""\
|
||||||
@ -5570,6 +5584,11 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
|||||||
add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
|
add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
|
||||||
return (add,)""",
|
return (add,)""",
|
||||||
)
|
)
|
||||||
|
schema = get_hop_schema(ep)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(schema),
|
||||||
|
"""cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",
|
||||||
|
)
|
||||||
|
|
||||||
cond_top_level_nn_module_stack = [
|
cond_top_level_nn_module_stack = [
|
||||||
node.meta["nn_module_stack"]
|
node.meta["nn_module_stack"]
|
||||||
@ -5655,6 +5674,12 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
|||||||
strict=False,
|
strict=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
schema = get_hop_schema(exported_program)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(schema),
|
||||||
|
"""cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""", # noqa: B950
|
||||||
|
)
|
||||||
|
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
str(exported_program.graph_module.code.strip()),
|
str(exported_program.graph_module.code.strip()),
|
||||||
"""\
|
"""\
|
||||||
@ -6139,6 +6164,12 @@ def forward(self, x):
|
|||||||
(torch.randn(4), torch.randn(4), torch.randn(4)),
|
(torch.randn(4), torch.randn(4), torch.randn(4)),
|
||||||
pre_dispatch=True,
|
pre_dispatch=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
schema = get_hop_schema(ep)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(schema),
|
||||||
|
"""cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",
|
||||||
|
)
|
||||||
# test cond subgraph
|
# test cond subgraph
|
||||||
expected_names_and_ops = [
|
expected_names_and_ops = [
|
||||||
("mul_2", "placeholder"),
|
("mul_2", "placeholder"),
|
||||||
|
@ -3711,6 +3711,143 @@ def forward(self, l_inp_, l_tmp_):
|
|||||||
self.assertEqual(cnt.frame_count, 3)
|
self.assertEqual(cnt.frame_count, 3)
|
||||||
|
|
||||||
|
|
||||||
|
_hop_schema_test_schema_types = [
|
||||||
|
"bool",
|
||||||
|
"int",
|
||||||
|
"float",
|
||||||
|
"str",
|
||||||
|
"Tensor",
|
||||||
|
"SymInt",
|
||||||
|
"SymBool",
|
||||||
|
"GraphModule",
|
||||||
|
"ScriptObj",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
||||||
|
class TestHopSchema(TestCase):
|
||||||
|
def _get_example_val(self, ty: str):
|
||||||
|
from torch.fx.experimental.sym_node import SymNode
|
||||||
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||||
|
|
||||||
|
def create_symtype(cls, pytype, shape_env, val):
|
||||||
|
from torch._dynamo.source import ConstantSource
|
||||||
|
|
||||||
|
symbol = shape_env.create_symbol(
|
||||||
|
val,
|
||||||
|
source=ConstantSource(
|
||||||
|
f"__testing_hop_schema{len(shape_env.var_to_val)}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return cls(SymNode(symbol, shape_env, pytype, hint=val))
|
||||||
|
|
||||||
|
if ty == "bool":
|
||||||
|
return True
|
||||||
|
elif ty == "int":
|
||||||
|
return 1
|
||||||
|
elif ty == "float":
|
||||||
|
return 1.0
|
||||||
|
elif ty == "str":
|
||||||
|
return "foo"
|
||||||
|
elif ty == "Tensor":
|
||||||
|
return torch.tensor(1)
|
||||||
|
elif ty == "SymInt":
|
||||||
|
shape_env = ShapeEnv()
|
||||||
|
return create_symtype(torch.SymInt, int, shape_env, 1)
|
||||||
|
elif ty == "SymBool":
|
||||||
|
shape_env = ShapeEnv()
|
||||||
|
return create_symtype(torch.SymBool, bool, shape_env, True)
|
||||||
|
elif ty == "GraphModule":
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return x.sin()
|
||||||
|
|
||||||
|
return make_fx(f)(torch.ones(1))
|
||||||
|
elif ty == "ScriptObj":
|
||||||
|
from torch.testing._internal.torchbind_impls import (
|
||||||
|
init_torchbind_implementations,
|
||||||
|
)
|
||||||
|
|
||||||
|
init_torchbind_implementations()
|
||||||
|
foo = torch.classes._TorchScriptTesting._Foo(3, 4)
|
||||||
|
return foo
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(ty)
|
||||||
|
|
||||||
|
@parametrize("schema_type", _hop_schema_test_schema_types)
|
||||||
|
def test_type_gen(self, schema_type):
|
||||||
|
from torchgen.gen_schema_utils import TypeGen
|
||||||
|
|
||||||
|
example_val = self._get_example_val(schema_type)
|
||||||
|
ty = TypeGen.from_example(example_val)
|
||||||
|
# Test the generated type can be parsed
|
||||||
|
self.assertEqual(ty.parse(str(ty)), ty)
|
||||||
|
|
||||||
|
@parametrize("schema_type", _hop_schema_test_schema_types)
|
||||||
|
def test_list_gen(self, schema_type):
|
||||||
|
from torchgen.gen_schema_utils import TypeGen
|
||||||
|
|
||||||
|
example_val = self._get_example_val(schema_type)
|
||||||
|
li1 = [example_val]
|
||||||
|
li2 = [example_val, example_val]
|
||||||
|
ty1 = TypeGen.from_example(li1)
|
||||||
|
ty2 = TypeGen.from_example(li1)
|
||||||
|
self.assertEqual(ty1.parse(str(ty1)), ty1)
|
||||||
|
self.assertEqual(ty2.parse(str(ty2)), ty2)
|
||||||
|
|
||||||
|
def test_function_schema_gen(self):
|
||||||
|
from torchgen.gen_schema_utils import FunctionSchemaGen
|
||||||
|
|
||||||
|
inps = [
|
||||||
|
(schema_type + "_v", self._get_example_val(schema_type))
|
||||||
|
for schema_type in _hop_schema_test_schema_types
|
||||||
|
]
|
||||||
|
op_name = "test_op"
|
||||||
|
schema1 = FunctionSchemaGen.from_example("test_op1", inps, torch.ones(1))
|
||||||
|
schema2 = FunctionSchemaGen.from_example(
|
||||||
|
"test_op2",
|
||||||
|
inps,
|
||||||
|
[
|
||||||
|
torch.ones(1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
schema3 = FunctionSchemaGen.from_example(
|
||||||
|
"test_op3", inps, [torch.ones(1), torch.ones(1)]
|
||||||
|
)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(schema1),
|
||||||
|
"""test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950
|
||||||
|
)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(schema2),
|
||||||
|
"""test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950
|
||||||
|
)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(schema3),
|
||||||
|
"""test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""", # noqa: B950,
|
||||||
|
)
|
||||||
|
self.assertEqual(schema1.parse(str(schema1)), schema1)
|
||||||
|
self.assertEqual(schema2.parse(str(schema2)), schema2)
|
||||||
|
self.assertEqual(schema3.parse(str(schema3)), schema3)
|
||||||
|
|
||||||
|
def test_while_loop_schema_gen(self):
|
||||||
|
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
|
||||||
|
graph = make_fx(fn)(*inp).graph
|
||||||
|
while_loop_node = next(
|
||||||
|
node
|
||||||
|
for node in graph.nodes
|
||||||
|
if node.op == "call_function"
|
||||||
|
and node.target is torch.ops.higher_order.while_loop
|
||||||
|
)
|
||||||
|
schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(schema),
|
||||||
|
"""while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""", # noqa: B950
|
||||||
|
)
|
||||||
|
self.assertEqual(schema.parse(str(schema)), schema)
|
||||||
|
|
||||||
|
|
||||||
|
instantiate_parametrized_tests(TestHopSchema)
|
||||||
instantiate_parametrized_tests(TestControlFlowTraced)
|
instantiate_parametrized_tests(TestControlFlowTraced)
|
||||||
|
|
||||||
instantiate_parametrized_tests(TestControlFlow)
|
instantiate_parametrized_tests(TestControlFlow)
|
||||||
|
@ -37,6 +37,17 @@ def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
|
|||||||
return combined_leaves
|
return combined_leaves
|
||||||
|
|
||||||
|
|
||||||
|
class AssociativeScanOp(HigherOrderOperator):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("associative_scan")
|
||||||
|
|
||||||
|
def __call__(self, combine_fn, input, dim):
|
||||||
|
return super().__call__(combine_fn, input, dim)
|
||||||
|
|
||||||
|
|
||||||
|
associative_scan_op = AssociativeScanOp()
|
||||||
|
|
||||||
|
|
||||||
def associative_scan(
|
def associative_scan(
|
||||||
combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
|
combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
|
||||||
input: pytree.PyTree,
|
input: pytree.PyTree,
|
||||||
@ -110,9 +121,6 @@ def associative_scan(
|
|||||||
return pytree.tree_unflatten(result_flat, spec)
|
return pytree.tree_unflatten(result_flat, spec)
|
||||||
|
|
||||||
|
|
||||||
associative_scan_op = HigherOrderOperator("associative_scan")
|
|
||||||
|
|
||||||
|
|
||||||
def trace_associative_scan(
|
def trace_associative_scan(
|
||||||
proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
|
proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
|
||||||
):
|
):
|
||||||
|
@ -40,6 +40,22 @@ from .utils import _from_fun, create_fw_bw_graph
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
"""
|
||||||
|
We're going to define a `cond_op` operation.
|
||||||
|
In order to do this, we need implementations for each of the dispatch keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class CondOp(HigherOrderOperator):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("cond")
|
||||||
|
|
||||||
|
def __call__(self, pred, true_fn, false_fn, operands):
|
||||||
|
return super().__call__(pred, true_fn, false_fn, operands)
|
||||||
|
|
||||||
|
|
||||||
|
cond_op = CondOp()
|
||||||
|
|
||||||
|
|
||||||
@exposed_in("torch")
|
@exposed_in("torch")
|
||||||
def cond(pred, true_fn, false_fn, operands):
|
def cond(pred, true_fn, false_fn, operands):
|
||||||
@ -160,13 +176,6 @@ def cond(pred, true_fn, false_fn, operands):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
We're going to define a `cond_op` operation.
|
|
||||||
In order to do this, we need implementations for each of the dispatch keys.
|
|
||||||
"""
|
|
||||||
cond_op = HigherOrderOperator("cond")
|
|
||||||
|
|
||||||
|
|
||||||
def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
|
def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
|
||||||
# See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
|
# See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
|
||||||
|
|
||||||
|
@ -201,6 +201,44 @@ def zip_schema(
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def hop_schema_from_fx_node(node):
|
||||||
|
from torchgen.gen_schema_utils import FunctionSchemaGen
|
||||||
|
|
||||||
|
hop = node.target
|
||||||
|
if not isinstance(hop, torch._ops.HigherOrderOperator):
|
||||||
|
raise RuntimeError("fx_node's target must be a hop.")
|
||||||
|
|
||||||
|
def _collect_example_val(node):
|
||||||
|
meta_val = node.meta.get("val", None)
|
||||||
|
if meta_val is None:
|
||||||
|
assert node.op == "get_attr"
|
||||||
|
meta_val = getattr(node.graph.owning_module, node.target)
|
||||||
|
return meta_val
|
||||||
|
|
||||||
|
example_inputs = []
|
||||||
|
for arg in node.args:
|
||||||
|
if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
|
||||||
|
example_inputs.append(_collect_example_val(arg))
|
||||||
|
elif isinstance(
|
||||||
|
arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
|
||||||
|
):
|
||||||
|
example_inputs.append([_collect_example_val(x) for x in arg])
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported arg type {type(arg)}")
|
||||||
|
|
||||||
|
# Bound the arguments to make sure number of inputs are correct
|
||||||
|
bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
|
||||||
|
*example_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
# We treat example_output as a single value in return. This is to differentiate 1. return a single val
|
||||||
|
# vs 2. return a tuple with one element.
|
||||||
|
example_output = _collect_example_val(node)
|
||||||
|
return FunctionSchemaGen.from_example(
|
||||||
|
hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
|
def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
|
||||||
assert isinstance(op, OpOverload)
|
assert isinstance(op, OpOverload)
|
||||||
if is_builtin(op):
|
if is_builtin(op):
|
||||||
|
97
torchgen/gen_schema_utils.py
Normal file
97
torchgen/gen_schema_utils.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
from typing import Any, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from torchgen.model import (
|
||||||
|
Annotation,
|
||||||
|
Argument,
|
||||||
|
Arguments,
|
||||||
|
BaseOperatorName,
|
||||||
|
BaseTy,
|
||||||
|
BaseType,
|
||||||
|
CustomClassType,
|
||||||
|
FunctionSchema,
|
||||||
|
ListType,
|
||||||
|
OperatorName,
|
||||||
|
Return,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Note: These aren't actually used in torchgen, they're some utilities for generating a schema
|
||||||
|
# from real arguments. For example, this is used to generate HigherOrderOperators' schema since
|
||||||
|
# their schemas can vary for different instances of the same HOP.
|
||||||
|
|
||||||
|
|
||||||
|
class TypeGen:
|
||||||
|
convert_to_base_ty = {
|
||||||
|
int: BaseTy.int,
|
||||||
|
float: BaseTy.float,
|
||||||
|
str: BaseTy.str,
|
||||||
|
bool: BaseTy.bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if isinstance(obj, torch.fx.GraphModule):
|
||||||
|
return BaseType(BaseTy.GraphModule)
|
||||||
|
elif isinstance(obj, torch.Tensor):
|
||||||
|
return BaseType(BaseTy.Tensor)
|
||||||
|
elif isinstance(obj, torch.SymInt):
|
||||||
|
return BaseType(BaseTy.SymInt)
|
||||||
|
elif isinstance(obj, torch.SymBool):
|
||||||
|
return BaseType(BaseTy.SymBool)
|
||||||
|
elif isinstance(obj, torch.ScriptObject):
|
||||||
|
return CustomClassType(obj._type().name()) # type: ignore[attr-defined]
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
assert len(obj) > 0
|
||||||
|
all_base_tys = [TypeGen.from_example(x) for x in obj]
|
||||||
|
if len(set(all_base_tys)) > 1:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot generate schema for a seqeunce of args of heterogeneous types: {all_base_tys}. "
|
||||||
|
"Consider unpacking the argument and give proper names to them if possible "
|
||||||
|
"instead of using *args."
|
||||||
|
)
|
||||||
|
return ListType(all_base_tys[0], len(obj))
|
||||||
|
tp = type(obj)
|
||||||
|
if tp not in TypeGen.convert_to_base_ty:
|
||||||
|
raise RuntimeError(f"unsupported type {tp}")
|
||||||
|
return BaseType(TypeGen.convert_to_base_ty[tp])
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnGen:
|
||||||
|
@staticmethod
|
||||||
|
def from_example(
|
||||||
|
name: Optional[str], obj: Any, annotation: Optional[Annotation]
|
||||||
|
) -> Return:
|
||||||
|
return Return(name, TypeGen.from_example(obj), annotation)
|
||||||
|
|
||||||
|
|
||||||
|
class ArgumentGen:
|
||||||
|
@staticmethod
|
||||||
|
def from_example(
|
||||||
|
name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
|
||||||
|
) -> Argument:
|
||||||
|
return Argument(
|
||||||
|
name, TypeGen.from_example(obj), default=default, annotation=annotation
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionSchemaGen:
|
||||||
|
@staticmethod
|
||||||
|
def from_example(
|
||||||
|
op_name: str,
|
||||||
|
example_inputs: Tuple[Tuple[str, Any], ...],
|
||||||
|
example_outputs: Tuple[Any, ...],
|
||||||
|
) -> FunctionSchema:
|
||||||
|
args = []
|
||||||
|
for name, inp in example_inputs:
|
||||||
|
args.append(ArgumentGen.from_example(name, inp, None, None))
|
||||||
|
# ignore the annotations and other attributes for now, we could add more when needed.
|
||||||
|
arguments = Arguments(
|
||||||
|
tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
|
||||||
|
)
|
||||||
|
returns = tuple(
|
||||||
|
ReturnGen.from_example(None, out, None) for out in example_outputs
|
||||||
|
)
|
||||||
|
op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
|
||||||
|
return FunctionSchema(op_name, arguments, returns)
|
@ -1880,7 +1880,9 @@ class BaseTy(Enum):
|
|||||||
Storage = auto()
|
Storage = auto()
|
||||||
Stream = auto()
|
Stream = auto()
|
||||||
SymInt = auto()
|
SymInt = auto()
|
||||||
|
SymBool = auto()
|
||||||
ConstQuantizerPtr = auto() # TODO: rename
|
ConstQuantizerPtr = auto() # TODO: rename
|
||||||
|
GraphModule = auto()
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
Reference in New Issue
Block a user